diff --git a/serialx/common.py b/serialx/common.py index 87e0f52..c029c1d 100644 --- a/serialx/common.py +++ b/serialx/common.py @@ -896,6 +896,13 @@ def __init__( self._serial: BaseSerial | None = None self._closing: bool = False self._closed_waiter: asyncio.Future[None] = loop.create_future() + self._connection_made_called: bool = False + self._connection_lost_called: bool = False + self._user_initiated_close: bool = False + + def _mark_user_closed(self) -> None: + """Record that the application requested close/abort.""" + self._user_initiated_close = True def _mark_broken(self, exc: Exception) -> None: if self._serial is not None: @@ -913,9 +920,23 @@ def _resolve_closed_waiter(self) -> None: if not self._closed_waiter.done(): self._closed_waiter.set_result(None) + def _call_protocol_connection_made(self) -> None: + """Mark connection_made and dispatch to the protocol exactly once.""" + assert not self._connection_made_called + + self._connection_made_called = True + self._protocol.connection_made(self) + def _call_protocol_connection_lost(self, exc: Exception | None) -> None: + """Idempotent dispatch of connection_lost.""" + if self._connection_lost_called: + return + + self._connection_lost_called = True + try: - self._protocol.connection_lost(exc) + if self._connection_made_called: + self._protocol.connection_lost(exc) except (SystemExit, KeyboardInterrupt): raise except BaseException as protocol_exc: diff --git a/serialx/descriptor_transport.py b/serialx/descriptor_transport.py index 875fad5..850992a 100644 --- a/serialx/descriptor_transport.py +++ b/serialx/descriptor_transport.py @@ -69,7 +69,6 @@ def __init__( self._close_task: asyncio.Task[None] | None = None self._open_fut: asyncio.Future[int] | None = None - self._connection_made: bool = False async def _open(self, path: str | os.PathLike[str]) -> None: if self._open_fut is not None: @@ -80,10 +79,15 @@ async def _open(self, path: str | os.PathLike[str]) -> None: ) try: - self._fileno = await self._open_fut + # Shield so that a cancellation of the awaiting task does NOT cancel + # the executor future. Otherwise, when `os.open` completes after the + # cancel, `future.set_result(fd)` is rejected (future already + # cancelled) and the fd is silently leaked. + self._fileno = await asyncio.shield(self._open_fut) except asyncio.CancelledError: - # `os.open` may still finish in the executor after cancellation. If that - # happens, close the resulting fd to avoid leaks. + # `os.open` may still finish in the executor after cancellation. The + # shield kept the underlying future alive, so the done-callback will + # see the fd and arrange to close it. self._open_fut.add_done_callback(self._on_cancelled_open_done) raise except BaseException: @@ -117,7 +121,6 @@ async def _connect( ) -> None: assert self._fileno is not None self._loop.add_reader(self._fileno, self._read_ready) - self._connection_made = True def _read_ready(self) -> None: LOGGER.debug("Event loop woke up reader") @@ -351,6 +354,7 @@ def write_eof(self) -> None: def close(self) -> None: """Close the transport.""" LOGGER.debug("Closing at the request of the application") + self._mark_user_closed() if self._closing: if ( self._fileno is None @@ -402,6 +406,7 @@ def _fatal_error( def abort(self) -> None: """Abort the transport immediately.""" + self._mark_user_closed() self._close(None) def _close(self, exc: Exception | None = None) -> None: @@ -463,20 +468,4 @@ async def _call_connection_lost(self, exc: Exception | None) -> None: LOGGER.debug("Closing file descriptor %s", fileno) await self._loop.run_in_executor(None, _safe_close, fileno) finally: - if self._connection_made: - LOGGER.debug("Calling protocol `connection_lost` with exc=%r", exc) - try: - self._protocol.connection_lost(exc) - except (SystemExit, KeyboardInterrupt): - raise - except BaseException as protocol_exc: - self._loop.call_exception_handler( - { - "message": "protocol.connection_lost() failed", - "exception": protocol_exc, - "transport": self, - "protocol": self._protocol, - } - ) - - self._resolve_closed_waiter() + self._call_protocol_connection_lost(exc) diff --git a/serialx/platforms/serial_esphome.py b/serialx/platforms/serial_esphome.py index dfec983..85b1825 100644 --- a/serialx/platforms/serial_esphome.py +++ b/serialx/platforms/serial_esphome.py @@ -644,7 +644,7 @@ async def _connect( self._register_transport_data_handler() ) - self._protocol.connection_made(self) + self._call_protocol_connection_made() async def _register_transport_data_handler(self) -> Callable[[], None]: """Register `_on_data` on the client's loop and return the unsub.""" @@ -670,6 +670,8 @@ def _on_data(self, msg: SerialProxyDataReceived) -> None: def write(self, data: bytes | bytearray | memoryview) -> None: """Write data to the serial proxy.""" + if self._closing: + return assert self._serial is not None self._serial.write(data) @@ -682,6 +684,7 @@ def close(self) -> None: if self._closing: return self._closing = True + self._mark_user_closed() serial = self._serial if self._unsub is not None: @@ -708,6 +711,9 @@ def close(self) -> None: self._call_protocol_connection_lost(None) return + # TODO: clean shutdown without `wait_closed()` needs a public sync + # force-disconnect on APIClient (aioesphomeapi); today only the + # private `api._connection.force_disconnect()` is sync. self._close_task = self._loop.create_task(self._async_close(api)) def abort(self) -> None: diff --git a/serialx/platforms/serial_posix.py b/serialx/platforms/serial_posix.py index 8eef54d..6931942 100644 --- a/serialx/platforms/serial_posix.py +++ b/serialx/platforms/serial_posix.py @@ -523,7 +523,7 @@ async def _connect( # type: ignore[override] await super()._connect() - self._protocol.connection_made(self) + self._call_protocol_connection_made() async def _flush(self) -> None: """Flush write buffers, waiting until all data is written, internal.""" diff --git a/serialx/platforms/serial_pyodide/__init__.py b/serialx/platforms/serial_pyodide/__init__.py index dcec13f..14cd514 100644 --- a/serialx/platforms/serial_pyodide/__init__.py +++ b/serialx/platforms/serial_pyodide/__init__.py @@ -251,7 +251,7 @@ async def _connect( # type: ignore[override] self._reader_task = self._loop.create_task(self._reader_loop()) self._writer_task = self._loop.create_task(self._writer_loop()) - self._protocol.connection_made(self) + self._call_protocol_connection_made() async def _writer_loop(self) -> None: while True: @@ -308,6 +308,8 @@ async def _set_modem_pins(self, modem_pins: ModemPins) -> None: def write(self, data: bytes | bytearray | memoryview) -> None: """Write data to the transport.""" + if self._closing: + return self._write_buffer_size += len(data) self._write_queue.put_nowait(bytes(data)) @@ -323,6 +325,12 @@ def abort(self) -> None: """Close the transport immediately, discarding pending writes.""" if self._writer_task is not None and not self._writer_task.done(): self._writer_task.cancel() + + while not self._write_queue.empty(): + self._write_queue.get_nowait() + self._write_queue.task_done() + self._write_buffer_size = 0 + self._cleanup(None) def __del__(self) -> None: diff --git a/serialx/platforms/serial_rfc2217/__init__.py b/serialx/platforms/serial_rfc2217/__init__.py index 73b97c4..8a7d06b 100644 --- a/serialx/platforms/serial_rfc2217/__init__.py +++ b/serialx/platforms/serial_rfc2217/__init__.py @@ -722,8 +722,6 @@ def __init__( self._rfc2217_waiters: dict[Rfc2217CmdId, asyncio.Future[Rfc2217Command]] = {} self._tcp_transport: asyncio.Transport | None = None self._tcp_connection_lost_waiter: asyncio.Future[None] | None = None - self._connection_lost_called = False - self._configured = False # -- connection lifecycle ----------------------------------------------- @@ -759,8 +757,7 @@ async def _connect( await self._negotiate() await self._configure_port() - self._configured = True - self._protocol.connection_made(self) + self._call_protocol_connection_made() async def _negotiate(self) -> None: """Perform the initial WILL/DO handshake for COM-PORT-OPTION.""" @@ -882,7 +879,7 @@ def _data_received(self, data: bytes) -> None: for response in responses: self._send_command(response) - if serial_data and self._configured: + if serial_data and self._connection_made_called: self._protocol.data_received(serial_data) async def _send_and_wait(self, cmd: Rfc2217Command) -> Rfc2217Command: @@ -968,23 +965,27 @@ def _tcp_connection_lost(self, exc: Exception | None) -> None: if self._connection_lost_called: return - self._connection_lost_called = True self._closing = True self._tcp_transport = None - if exc is None: - exc = OSError(errno.EIO, "RFC 2217 connection closed by server") - self._mark_broken(exc) + if not self._user_initiated_close: + if exc is None: + exc = OSError(errno.EIO, "RFC 2217 connection closed by server") + self._mark_broken(exc) - # Fail any pending waiters + # Pending in-protocol waiters can't resolve cleanly mid-handshake, so + # always fail them with *some* exception even on a user-initiated close. + waiter_exc = ( + exc if exc is not None else OSError(errno.EIO, "RFC 2217 transport closed") + ) for _expected, telnet_waiter in self._telnet_waiters: if not telnet_waiter.done(): - telnet_waiter.set_exception(exc) + telnet_waiter.set_exception(waiter_exc) self._telnet_waiters.clear() for rfc2217_waiter in self._rfc2217_waiters.values(): if not rfc2217_waiter.done(): - rfc2217_waiter.set_exception(exc) + rfc2217_waiter.set_exception(waiter_exc) self._rfc2217_waiters.clear() if self._serial is not None: @@ -1007,6 +1008,7 @@ def close(self) -> None: if self._connection_lost_called: return self._closing = True + self._mark_user_closed() if self._tcp_transport is not None: self._tcp_transport.close() @@ -1018,6 +1020,7 @@ def abort(self) -> None: if self._connection_lost_called: return self._closing = True + self._mark_user_closed() if self._tcp_transport is not None: self._tcp_transport.abort() diff --git a/serialx/platforms/serial_socket.py b/serialx/platforms/serial_socket.py index ea44fe3..9a8fa37 100644 --- a/serialx/platforms/serial_socket.py +++ b/serialx/platforms/serial_socket.py @@ -219,7 +219,6 @@ def __init__( super().__init__(loop, protocol) self._tcp_transport: asyncio.Transport | None = None self._tcp_connection_lost_waiter: asyncio.Future[None] | None = None - self._connection_lost_called = False async def _connect( # type: ignore[override] self, @@ -254,6 +253,7 @@ async def _connect( # type: ignore[override] host=self._serial._host, port=self._serial._port, ) + self._tcp_transport = tcp_transport if self._connection_lost_called: @@ -264,7 +264,7 @@ async def _connect( # type: ignore[override] self._tcp_transport = None return - self._protocol.connection_made(self) + self._call_protocol_connection_made() def _data_received(self, data: bytes) -> None: """Handle data received from the TCP transport.""" @@ -282,12 +282,12 @@ def _connection_lost(self, exc: Exception | None) -> None: """Handle connection lost from the TCP transport.""" if self._connection_lost_called: return - self._connection_lost_called = True self._closing = True self._tcp_transport = None - if exc is None: - exc = OSError(errno.EIO, "socket closed by peer") - self._mark_broken(exc) + if not self._user_initiated_close: + if exc is None: + exc = OSError(errno.EIO, "socket closed by peer") + self._mark_broken(exc) self._call_protocol_connection_lost(exc) def _tcp_connection_lost(self) -> None: @@ -323,6 +323,7 @@ def abort(self) -> None: if self._connection_lost_called: return self._closing = True + self._mark_user_closed() if self._tcp_transport is not None: self._tcp_transport.abort() @@ -334,6 +335,7 @@ def close(self) -> None: if self._connection_lost_called: return self._closing = True + self._mark_user_closed() if self._tcp_transport is not None: self._tcp_transport.close() diff --git a/serialx/platforms/serial_win32.py b/serialx/platforms/serial_win32.py index 52fcae7..0aaf567 100644 --- a/serialx/platforms/serial_win32.py +++ b/serialx/platforms/serial_win32.py @@ -3,6 +3,7 @@ from __future__ import annotations import asyncio +import functools import logging import os from typing import TYPE_CHECKING, Any, cast @@ -126,6 +127,31 @@ def _safe_close_handle(handle: int) -> None: LOGGER.debug("Failed to close handle %r", handle, exc_info=True) +def CreateFile_detached( + *, + file_name: str, + desired_access: int, + share_mode: int, + creation_disposition: int, + flags_and_attributes: int, +) -> int: + """`CreateFile` returning a raw `int` HANDLE that the caller owns.""" + handle = CreateFile( + file_name, + desired_access, + share_mode, + None, + creation_disposition, + flags_and_attributes, + None, + ) + + # `Detach()` is stubbed `-> Self` but actually returns the underlying int. + # Without `Detach()`, we would get a `PyHANDLE` object that closes the handle on + # `__del__`, masking bugs. + return cast(int, handle.Detach()) + + class Win32Serial(BaseSerial): """Windows serial port implementation using Win32 API.""" @@ -160,17 +186,13 @@ def _open(self) -> None: share_mode = 0 if self._exclusive else FILE_SHARE_READ | FILE_SHARE_WRITE try: - handle = CreateFile( - path, - GENERIC_READ | GENERIC_WRITE, - share_mode, - None, - OPEN_EXISTING, - FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, - None, + self._handle = CreateFile_detached( + file_name=path, + desired_access=GENERIC_READ | GENERIC_WRITE, + share_mode=share_mode, + creation_disposition=OPEN_EXISTING, + flags_and_attributes=FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, ) - - self._handle = cast(int, handle) except pywintypes.error as e: raise OSError(e.winerror, e.strerror, path) from e @@ -454,6 +476,7 @@ def __init__( self._close_future: asyncio.Future[None] | None = None self._closing: bool = False self._connect_in_progress: bool = False + self._connection_made_waiter: asyncio.Future[None] | None = None def serial_close(self) -> None: """Close the serial port.""" @@ -488,11 +511,17 @@ def protocol_connection_made(self, transport: asyncio.Transport) -> None: """Forward connection_made to the protocol.""" # Ignore `transport` and pass self instead - self._protocol.connection_made(self) + self._call_protocol_connection_made() + + if ( + self._connection_made_waiter is not None + and not self._connection_made_waiter.done() + ): + self._connection_made_waiter.set_result(None) def protocol_connection_lost(self, exc: Exception | None) -> None: """Forward connection_lost to the protocol.""" - self._resolve_closed_waiter() + self._call_protocol_connection_lost(exc) def protocol_pause_writing(self) -> None: """Forward pause_writing to the protocol.""" @@ -526,10 +555,10 @@ def _on_cancelled_open_done(self, open_fut: asyncio.Future[int]) -> None: try: handle = open_fut.result() except BaseException: - self._maybe_resolve_closed_waiter() - return + pass + else: + _safe_close_handle(handle) - _safe_close_handle(handle) self._maybe_resolve_closed_waiter() async def _open( @@ -542,22 +571,28 @@ async def _open( normalized_path = _normalize_windows_port_path(path) share_mode = 0 if exclusive else FILE_SHARE_READ | FILE_SHARE_WRITE - open_fut = self._loop.run_in_executor( + # Use `CreateFile_detached` so the future's result is a plain `int`, + # matching the shape of `os.open` on POSIX. The PyHANDLE wrapper + # never escapes the executor thread, so its `tp_dealloc` can't paper + # over a missing explicit close on cancel. + self._open_fut = self._loop.run_in_executor( None, - lambda: CreateFile( - normalized_path, - GENERIC_READ | GENERIC_WRITE, - share_mode, - None, - OPEN_EXISTING, - FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, - None, + functools.partial( + CreateFile_detached, + file_name=normalized_path, + desired_access=GENERIC_READ | GENERIC_WRITE, + share_mode=share_mode, + creation_disposition=OPEN_EXISTING, + flags_and_attributes=FILE_ATTRIBUTE_NORMAL | FILE_FLAG_OVERLAPPED, ), ) - self._open_fut = cast(asyncio.Future[int], open_fut) try: - handle = await self._open_fut + # Shield so a cancellation of the awaiting task doesn't cancel the + # executor future. Otherwise, when `CreateFile` completes after the + # cancel, `wrap_future` drops the result silently and the HANDLE + # is leaked. + handle = await asyncio.shield(self._open_fut) except asyncio.CancelledError: self._open_fut.add_done_callback(self._on_cancelled_open_done) raise @@ -611,6 +646,7 @@ async def _connect( # Use the internal _make_duplex_pipe_transport to create a true overlapping # bidirectional transport on the single handle. assert hasattr(self._loop, "_make_duplex_pipe_transport") + self._connection_made_waiter = self._loop.create_future() self._internal_transport = self._loop._make_duplex_pipe_transport( # Proxy access to serial and protocol attributes through this instance sock=_MethodProxy( @@ -635,6 +671,11 @@ async def _connect( ) if self._closing: self._internal_transport.close() # type: ignore[unreachable] + return + + # The internal duplex pipe transport schedules connection_made via + # call_soon. Wait for it so callers can assume MADE upon return. + await self._connection_made_waiter except BaseException: if self._handle is not None: await self._loop.run_in_executor(None, _safe_close_handle, self._handle) diff --git a/tests/common.py b/tests/common.py index 91a43c1..95ededf 100644 --- a/tests/common.py +++ b/tests/common.py @@ -14,6 +14,7 @@ import subprocess import sys import tempfile +import threading import time from typing import IO, Any @@ -224,7 +225,6 @@ def _snapshot_fds() -> set[int]: def check_fd_leaks() -> Iterator[None]: """Fail if any file descriptor is opened in this block without being closed.""" before = _snapshot_fds() - try: yield finally: @@ -321,7 +321,10 @@ def create_adapter_pair(left: str, right: str) -> Iterator[tuple[str, str]]: f.fd, proc.info["cmdline"], ) - except (psutil.NoSuchProcess, psutil.AccessDenied): # noqa: PERF203 + except ( # noqa: PERF203 + psutil.NoSuchProcess, + psutil.AccessDenied, + ): pass # Check our own process @@ -593,6 +596,7 @@ def create_hub4com_pair( ] procs = [] + drain_threads: list[threading.Thread] = [] try: for adapter in (left_adapter, right_adapter): @@ -618,6 +622,16 @@ def create_hub4com_pair( name="hub4com", ) + # Drain hub4com's stdout/stderr. Otherwise the OS pipe buffers fill after a + # handful of sessions and hub4com blocks. + for proc in procs: + for stream in (proc.stdout, proc.stderr): + t = threading.Thread( + target=lambda s=stream: list(iter(s.readline, b"")), + ) + t.start() + drain_threads.append(t) + left, right = [_get_listening_ports(proc.pid)[0] for proc in procs] yield ( @@ -630,6 +644,9 @@ def create_hub4com_pair( proc.terminate() proc.wait() + for t in drain_threads: + t.join() + @contextlib.asynccontextmanager async def async_create_serial_pair( diff --git a/tests/test_async_lifecycle.py b/tests/test_async_lifecycle.py new file mode 100644 index 0000000..de22ee6 --- /dev/null +++ b/tests/test_async_lifecycle.py @@ -0,0 +1,706 @@ +"""Granular lifecycle and race-condition tests for async serial transports.""" + +from __future__ import annotations + +import asyncio +from collections.abc import Callable, Iterator +import contextlib +import enum +import gc +import importlib +import sys +import threading +from typing import Any +from unittest.mock import patch +import warnings + +import pytest + +from serialx import BaseSerialTransport, create_serial_connection +from tests.common import SerialBackend, SerialPair + + +class ProtocolState(enum.Enum): + """Asyncio Protocol lifecycle state.""" + + INIT = "init" # constructed, before connection_made + MADE = "made" # connection_made called, before connection_lost + LOST = "lost" # connection_lost called (terminal) + + +class RecordingProtocol(asyncio.Protocol): + """asyncio.Protocol that enforces lifecycle invariants and records data.""" + + connection_made_transport: BaseSerialTransport | None + connection_lost_exc: Exception | None + + def __init__(self) -> None: + """Initialize in the INIT state.""" + self._state = ProtocolState.INIT + self.violations: list[str] = [] + self.connection_made_transport = None + self.connection_lost_exc = None + self.data_received_chunks: list[bytes] = [] + self._state_waiters: dict[ProtocolState, list[asyncio.Future[None]]] = {} + + def _set_state(self, state: ProtocolState) -> None: + self._state = state + for fut in self._state_waiters.pop(state, []): + if not fut.done(): + fut.set_result(None) + + async def wait_for_state(self, state: ProtocolState) -> None: + """Resolve once the protocol has reached `state`.""" + if self._state is state: + return + fut: asyncio.Future[None] = asyncio.get_running_loop().create_future() + self._state_waiters.setdefault(state, []).append(fut) + await fut + + @property + def state(self) -> ProtocolState: + """Current lifecycle state.""" + return self._state + + def assert_state(self, expected: ProtocolState) -> None: + """Assert the protocol is in `expected` state.""" + assert self._state is expected, f"state={self._state}, expected={expected}" + + def _require_state(self, callback: str, *expected: ProtocolState) -> None: + if self._state in expected: + return + allowed = ", ".join(s.value for s in expected) + self.violations.append( + f"{callback} called in state {self._state.value!r}; allowed: {{{allowed}}}" + ) + + def connection_made(self, transport: asyncio.BaseTransport) -> None: + """Enforce INIT -> MADE.""" + self._require_state("connection_made", ProtocolState.INIT) + assert isinstance(transport, BaseSerialTransport) + self.connection_made_transport = transport + self._set_state(ProtocolState.MADE) + + def connection_lost(self, exc: Exception | None) -> None: + """Enforce MADE -> LOST.""" + self._require_state("connection_lost", ProtocolState.MADE) + self.connection_lost_exc = exc + self._set_state(ProtocolState.LOST) + + def data_received(self, data: bytes) -> None: + """Record an incoming chunk; only valid in MADE.""" + self._require_state("data_received", ProtocolState.MADE) + self.data_received_chunks.append(data) + + def eof_received(self) -> bool | None: + """Record an EOF; only valid in MADE.""" + self._require_state("eof_received", ProtocolState.MADE) + return None + + def pause_writing(self) -> None: + """Only valid in MADE.""" + self._require_state("pause_writing", ProtocolState.MADE) + + def resume_writing(self) -> None: + """Only valid in MADE.""" + self._require_state("resume_writing", ProtocolState.MADE) + + def assert_clean(self) -> None: + """Fail the test if any state violation was recorded.""" + if self.violations: + pytest.fail("Protocol violations:\n " + "\n ".join(self.violations)) + + @property + def total_received(self) -> bytes: + """Concatenation of all received chunks.""" + return b"".join(self.data_received_chunks) + + +# --- Successful lifecycle: callbacks fire exactly once --- + + +async def test_lifecycle_normal_close_callbacks(serial_pair: SerialPair) -> None: + """Connect + graceful close: state machine traverses INIT -> MADE -> LOST.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + protocol.assert_state(ProtocolState.MADE) + assert protocol.connection_made_transport is transport + + transport.close() + await transport.wait_closed() + + protocol.assert_state(ProtocolState.LOST) + assert protocol.connection_lost_exc is None + protocol.assert_clean() + + +async def test_lifecycle_abort_callbacks(serial_pair: SerialPair) -> None: + """Connect + abort: traverses INIT -> MADE -> LOST.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + transport.abort() + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + assert protocol.connection_lost_exc is None + protocol.assert_clean() + + +# --- Cancellation during connect --- + + +async def test_lifecycle_cancel_during_connect_no_callbacks( + serial_pair: SerialPair, +) -> None: + """Cancel mid-connect: ends in INIT or LOST, never an inconsistent state.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + connect_task = asyncio.create_task( + create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + ) + await asyncio.sleep(0) + connect_task.cancel() + + with pytest.raises(asyncio.CancelledError): + await connect_task + + assert protocol.state in (ProtocolState.INIT, ProtocolState.LOST) + protocol.assert_clean() + + +# --- Idempotency under repeated close/abort --- + + +async def test_lifecycle_close_close_one_connection_lost( + serial_pair: SerialPair, +) -> None: + """Two close() calls: state machine ensures connection_lost fires once.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + transport.close() + transport.close() + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +async def test_lifecycle_abort_after_close_one_connection_lost( + serial_pair: SerialPair, +) -> None: + """abort() after close().""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + transport.close() + transport.abort() + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +async def test_lifecycle_close_after_abort_one_connection_lost( + serial_pair: SerialPair, +) -> None: + """close() after abort().""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + transport.abort() + transport.close() + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +# --- Drain vs. abort semantics --- + + +async def test_lifecycle_close_drains_pending_writes( + serial_pair: SerialPair, +) -> None: + """close() with buffered data delivers all bytes to the peer.""" + loop = asyncio.get_running_loop() + payload = bytes(range(256)) * 16 # 4 KiB + sender_proto = RecordingProtocol() + receiver_proto = RecordingProtocol() + + sender, _ = await create_serial_connection( + loop, lambda: sender_proto, serial_pair.left, baudrate=115200 + ) + receiver, _ = await create_serial_connection( + loop, lambda: receiver_proto, serial_pair.right, baudrate=115200 + ) + + try: + sender.write(payload) + sender.close() # drain semantics + await sender.wait_closed() + + while len(receiver_proto.total_received) < len(payload): + await asyncio.sleep(0.05) + + assert receiver_proto.total_received == payload + finally: + receiver.close() + await receiver.wait_closed() + sender_proto.assert_clean() + receiver_proto.assert_clean() + + +async def test_lifecycle_abort_during_drain_escalates( + serial_pair: SerialPair, +) -> None: + """abort() called while close()'s drain is pending must escalate to abort semantics.""" + loop = asyncio.get_running_loop() + sender_proto = RecordingProtocol() + receiver_proto = RecordingProtocol() + + sender, _ = await create_serial_connection( + loop, lambda: sender_proto, serial_pair.left, baudrate=115200 + ) + receiver, _ = await create_serial_connection( + loop, lambda: receiver_proto, serial_pair.right, baudrate=115200 + ) + + try: + # Large payload to overflow the kernel TTY buffer and force user-space buffering. + sender.write(b"\x55" * (4 * 1024 * 1024)) + + if sender.get_write_buffer_size() == 0: + pytest.skip("Backend absorbed the entire write synchronously") + + # close() starts a drain because buffer is non-empty. + sender.close() + assert sender.is_closing() + assert sender.get_write_buffer_size() > 0, ( + "close() with a pending buffer must NOT clear it synchronously" + ) + + # abort() must escalate: buffer cleared, transport heading for closed. + sender.abort() + assert sender.get_write_buffer_size() == 0, ( + "abort() during drain must clear the buffer synchronously" + ) + + await sender.wait_closed() + assert sender_proto.state is ProtocolState.LOST + finally: + sender.close() + receiver.close() + await sender.wait_closed() + await receiver.wait_closed() + sender_proto.assert_clean() + receiver_proto.assert_clean() + + +# --- is_closing() state machine --- + + +async def test_lifecycle_is_closing_states(serial_pair: SerialPair) -> None: + """is_closing() reflects the lifecycle: False -> True at close request -> True after wait.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + assert transport.is_closing() is False + + transport.close() + assert transport.is_closing() is True + + await transport.wait_closed() + assert transport.is_closing() is True + protocol.assert_clean() + + +# --- Concurrent wait_closed waiters --- + + +async def test_lifecycle_concurrent_wait_closed(serial_pair: SerialPair) -> None: + """All concurrent wait_closed() awaiters resolve when the transport closes.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + waiters = [asyncio.create_task(transport.wait_closed()) for _ in range(5)] + await asyncio.sleep(0) + transport.close() + + results = await asyncio.gather(*waiters) + assert results == [None] * 5 + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +# --- Open failure path --- + + +async def test_lifecycle_open_failure_no_callbacks() -> None: + """A failed os.open: connect raises, with no protocol callbacks fired.""" + if sys.platform == "emscripten": + pytest.skip("No POSIX/Windows-style device paths under Pyodide") + + loop = asyncio.get_running_loop() + path = "COM25" if sys.platform == "win32" else "/dev/this_port_does_not_exist" + protocol = RecordingProtocol() + + with pytest.raises(OSError): + await create_serial_connection(loop, lambda: protocol, path, baudrate=115200) + + assert protocol.state is ProtocolState.INIT + protocol.assert_clean() + + +# --- Close from inside connection_made (re-entrancy) --- + + +async def test_lifecycle_close_from_connection_made(serial_pair: SerialPair) -> None: + """A protocol that calls close() inside connection_made: lost still fires once.""" + loop = asyncio.get_running_loop() + + class CloseInsideMade(RecordingProtocol): + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + assert isinstance(transport, BaseSerialTransport) + transport.close() + + protocol = CloseInsideMade() + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +async def test_lifecycle_abort_from_connection_made(serial_pair: SerialPair) -> None: + """A protocol that calls abort() inside connection_made: lost still fires once.""" + loop = asyncio.get_running_loop() + + class AbortInsideMade(RecordingProtocol): + def connection_made(self, transport: asyncio.BaseTransport) -> None: + super().connection_made(transport) + assert isinstance(transport, BaseSerialTransport) + transport.write(b"this should be discarded") + transport.abort() + + protocol = AbortInsideMade() + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +# --- Multiple cycles on the same path --- + + +async def test_lifecycle_repeated_open_close_cycles( + serial_pair: SerialPair, +) -> None: + """Repeated open/close cycles each produce one made + one lost callback.""" + loop = asyncio.get_running_loop() + + def make_factory(p: RecordingProtocol) -> Callable[[], RecordingProtocol]: + return lambda: p + + for _ in range(5): + protocol = RecordingProtocol() + transport, _ = await create_serial_connection( + loop, + make_factory(protocol), + serial_pair.left, + baudrate=115200, + ) + transport.close() + await transport.wait_closed() + + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +# --- Writes after close are silently dropped --- + + +async def test_lifecycle_write_after_close_is_dropped(serial_pair: SerialPair) -> None: + """Writes after close() are silently dropped (no exception, no delivery).""" + loop = asyncio.get_running_loop() + sender_proto = RecordingProtocol() + receiver_proto = RecordingProtocol() + + sender, _ = await create_serial_connection( + loop, lambda: sender_proto, serial_pair.left, baudrate=115200 + ) + receiver, _ = await create_serial_connection( + loop, lambda: receiver_proto, serial_pair.right, baudrate=115200 + ) + + try: + sender.close() + # These should be silently ignored + for _ in range(10): + sender.write(b"after close") + + await sender.wait_closed() + + # Drain the receiver briefly to make sure nothing leaked through. + await asyncio.sleep(0.1) + assert receiver_proto.data_received_chunks == [] + finally: + receiver.close() + await receiver.wait_closed() + sender_proto.assert_clean() + receiver_proto.assert_clean() + + +# --- Connection ordering: connection_made must precede any data_received --- + + +async def test_lifecycle_data_received_after_connection_made( + serial_pair: SerialPair, +) -> None: + """data_received must arrive only after connection_made (enforced by state machine).""" + loop = asyncio.get_running_loop() + + left_proto = RecordingProtocol() + right_proto = RecordingProtocol() + + left, _ = await create_serial_connection( + loop, lambda: left_proto, serial_pair.left, baudrate=115200 + ) + right, _ = await create_serial_connection( + loop, lambda: right_proto, serial_pair.right, baudrate=115200 + ) + + try: + right.write(b"hello") + await right.flush() + + while left_proto.total_received != b"hello": + await asyncio.sleep(0.01) + + assert left_proto.total_received == b"hello" + finally: + left.close() + right.close() + await left.wait_closed() + await right.wait_closed() + left_proto.assert_clean() + right_proto.assert_clean() + + +# --- _Pure_ wait_closed / close ordering, no transport state checks --- + + +async def test_lifecycle_wait_closed_before_close_blocks( + serial_pair: SerialPair, +) -> None: + """wait_closed() awaited before close() must not resolve until close() is called.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + + waiter = asyncio.create_task(transport.wait_closed()) + await asyncio.sleep(0.1) + assert not waiter.done(), "wait_closed must not resolve before close()" + + transport.close() + await waiter + assert protocol.state is ProtocolState.LOST + protocol.assert_clean() + + +# --- Connect kwargs: a misconfigured kwarg surfaces as an exception --- + + +async def test_lifecycle_invalid_kwarg_surfaces_no_callbacks( + serial_pair: SerialPair, +) -> None: + """A bad kwarg during connect raises and produces no protocol callbacks.""" + if SerialBackend.SOCKET in serial_pair.backends: + pytest.skip("socket transport does not validate serial settings") + + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + with pytest.raises(Exception): + await create_serial_connection( + loop, + lambda: protocol, + serial_pair.left, + baudrate=115200, + byte_size=99, + ) + + assert protocol.state in (ProtocolState.INIT, ProtocolState.LOST) + protocol.assert_clean() + + +# --- Cancellation race: fd must not leak when os.open is mid-syscall --- + + +@contextlib.contextmanager +def patch_slow( + *targets: str, +) -> Iterator[tuple[threading.Event, threading.Event, threading.Event]]: + """Patch each target callable to block on `proceed` after signaling `started`. + + Yields `(started, proceed, completed)`. `completed` is set after the real + underlying call has returned in every patched call site — useful when the + test needs to wait for the executor thread to finish before checking for + leaked resources. + """ + started = threading.Event() + proceed = threading.Event() + completed = threading.Event() + + def make_slow(real_fn: Callable[..., Any]) -> Callable[..., Any]: + def slow(*args: Any, **kwargs: Any) -> Any: + started.set() + if not proceed.wait(timeout=5.0): + raise TimeoutError("test setup: proceed never released") + try: + return real_fn(*args, **kwargs) + finally: + completed.set() + + return slow + + with contextlib.ExitStack() as stack: + for target in targets: + module_path, _, attr_name = target.rpartition(".") + module = importlib.import_module(module_path) + real = getattr(module, attr_name) + stack.enter_context(patch(target, new=make_slow(real))) + + try: + yield started, proceed, completed + finally: + proceed.set() + + +async def test_lifecycle_no_fd_leak_when_internal_task_cancelled_during_open( + serial_pair: SerialPair, +) -> None: + """Cancelling internal tasks during connect must not leak resources.""" + loop = asyncio.get_running_loop() + protocol = RecordingProtocol() + + slow_targets = ["os.open"] + + if sys.platform == "win32": + slow_targets.append("serialx.platforms.serial_win32.CreateFile") + + with patch_slow(*slow_targets) as (started, proceed, completed): + existing_tasks = asyncio.all_tasks(loop) + + async def connect() -> None: + transport, _ = await create_serial_connection( + loop, lambda: protocol, serial_pair.left, baudrate=115200 + ) + transport.close() + await transport.wait_closed() + + connect_task = asyncio.create_task(connect()) + + # Wait briefly for a patched syscall to fire. Network backends never + # hit one; for those this just gives the connect a head start. + await loop.run_in_executor(None, started.wait, 1.0) + + # Cancel every transport-internal in-flight task. When a patched + # syscall is mid-flight this hits the cancel window. + for task in asyncio.all_tasks(loop) - existing_tasks: + if task is asyncio.current_task() or task.done(): + continue + task.cancel() + + # Release any blocked syscall. The real call returns its handle/fd, + # and the executor tries to deliver to a (possibly cancelled) future. + proceed.set() + + with contextlib.suppress(asyncio.CancelledError): + await connect_task + + # Make sure the executor thread fully returned from the real syscall + # so that any leaked fd/handle is visible to teardown's snapshot. + if started.is_set(): + await loop.run_in_executor(None, completed.wait, 5.0) + + +def test_lifecycle_close_without_wait_closed_no_warnings( + serial_pair: SerialPair, +) -> None: + """close() and then shutdown doesn't log `Task was destroyed but it is pending`.""" + if SerialBackend.ESPHOME_HOST in serial_pair.backends: + pytest.skip( + "TODO: aioesphomeapi has no public sync force-disconnect; " + "see serial_esphome.py" + ) + + handler_calls: list[dict[str, Any]] = [] + + async def main() -> None: + loop = asyncio.get_running_loop() + transport, _ = await create_serial_connection( + loop, asyncio.Protocol, serial_pair.left, baudrate=115200 + ) + transport.close() + # Intentionally NOT awaiting `wait_closed` + + with warnings.catch_warnings(record=True) as caught_warnings: + warnings.simplefilter("always") + + loop = asyncio.new_event_loop() + loop.set_exception_handler(lambda _loop, ctx: handler_calls.append(ctx)) + + try: + loop.run_until_complete(main()) + finally: + loop.close() + + # Force GC so Task.__del__ runs and any "destroyed but pending" diagnostic + # reaches the exception handler before we check. + gc.collect() + + assert not caught_warnings