Skip to content
Open
23 changes: 22 additions & 1 deletion serialx/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -896,6 +896,13 @@ def __init__(
self._serial: BaseSerial | None = None
self._closing: bool = False
self._closed_waiter: asyncio.Future[None] = loop.create_future()
self._connection_made_called: bool = False
self._connection_lost_called: bool = False
self._user_initiated_close: bool = False

def _mark_user_closed(self) -> None:
"""Record that the application requested close/abort."""
self._user_initiated_close = True

def _mark_broken(self, exc: Exception) -> None:
if self._serial is not None:
Expand All @@ -913,9 +920,23 @@ def _resolve_closed_waiter(self) -> None:
if not self._closed_waiter.done():
self._closed_waiter.set_result(None)

def _call_protocol_connection_made(self) -> None:
"""Mark connection_made and dispatch to the protocol exactly once."""
assert not self._connection_made_called

self._connection_made_called = True
self._protocol.connection_made(self)

def _call_protocol_connection_lost(self, exc: Exception | None) -> None:
"""Idempotent dispatch of connection_lost."""
if self._connection_lost_called:
return

self._connection_lost_called = True

try:
self._protocol.connection_lost(exc)
if self._connection_made_called:
self._protocol.connection_lost(exc)
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as protocol_exc:
Expand Down
33 changes: 11 additions & 22 deletions serialx/descriptor_transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ def __init__(

self._close_task: asyncio.Task[None] | None = None
self._open_fut: asyncio.Future[int] | None = None
self._connection_made: bool = False

async def _open(self, path: str | os.PathLike[str]) -> None:
if self._open_fut is not None:
Expand All @@ -80,10 +79,15 @@ async def _open(self, path: str | os.PathLike[str]) -> None:
)

try:
self._fileno = await self._open_fut
# Shield so that a cancellation of the awaiting task does NOT cancel
# the executor future. Otherwise, when `os.open` completes after the
# cancel, `future.set_result(fd)` is rejected (future already
# cancelled) and the fd is silently leaked.
self._fileno = await asyncio.shield(self._open_fut)
except asyncio.CancelledError:
# `os.open` may still finish in the executor after cancellation. If that
# happens, close the resulting fd to avoid leaks.
# `os.open` may still finish in the executor after cancellation. The
# shield kept the underlying future alive, so the done-callback will
# see the fd and arrange to close it.
self._open_fut.add_done_callback(self._on_cancelled_open_done)
raise
except BaseException:
Expand Down Expand Up @@ -117,7 +121,6 @@ async def _connect(
) -> None:
assert self._fileno is not None
self._loop.add_reader(self._fileno, self._read_ready)
self._connection_made = True

def _read_ready(self) -> None:
LOGGER.debug("Event loop woke up reader")
Expand Down Expand Up @@ -351,6 +354,7 @@ def write_eof(self) -> None:
def close(self) -> None:
"""Close the transport."""
LOGGER.debug("Closing at the request of the application")
self._mark_user_closed()
if self._closing:
if (
self._fileno is None
Expand Down Expand Up @@ -402,6 +406,7 @@ def _fatal_error(

def abort(self) -> None:
"""Abort the transport immediately."""
self._mark_user_closed()
self._close(None)

def _close(self, exc: Exception | None = None) -> None:
Expand Down Expand Up @@ -463,20 +468,4 @@ async def _call_connection_lost(self, exc: Exception | None) -> None:
LOGGER.debug("Closing file descriptor %s", fileno)
await self._loop.run_in_executor(None, _safe_close, fileno)
finally:
if self._connection_made:
LOGGER.debug("Calling protocol `connection_lost` with exc=%r", exc)
try:
self._protocol.connection_lost(exc)
except (SystemExit, KeyboardInterrupt):
raise
except BaseException as protocol_exc:
self._loop.call_exception_handler(
{
"message": "protocol.connection_lost() failed",
"exception": protocol_exc,
"transport": self,
"protocol": self._protocol,
}
)

self._resolve_closed_waiter()
self._call_protocol_connection_lost(exc)
8 changes: 7 additions & 1 deletion serialx/platforms/serial_esphome.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ async def _connect(
self._register_transport_data_handler()
)

self._protocol.connection_made(self)
self._call_protocol_connection_made()

async def _register_transport_data_handler(self) -> Callable[[], None]:
"""Register `_on_data` on the client's loop and return the unsub."""
Expand All @@ -670,6 +670,8 @@ def _on_data(self, msg: SerialProxyDataReceived) -> None:

def write(self, data: bytes | bytearray | memoryview) -> None:
"""Write data to the serial proxy."""
if self._closing:
return
assert self._serial is not None
self._serial.write(data)

Expand All @@ -682,6 +684,7 @@ def close(self) -> None:
if self._closing:
return
self._closing = True
self._mark_user_closed()

serial = self._serial
if self._unsub is not None:
Expand All @@ -708,6 +711,9 @@ def close(self) -> None:
self._call_protocol_connection_lost(None)
return

# TODO: clean shutdown without `wait_closed()` needs a public sync
# force-disconnect on APIClient (aioesphomeapi); today only the
# private `api._connection.force_disconnect()` is sync.
self._close_task = self._loop.create_task(self._async_close(api))

def abort(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion serialx/platforms/serial_posix.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ async def _connect( # type: ignore[override]

await super()._connect()

self._protocol.connection_made(self)
self._call_protocol_connection_made()

async def _flush(self) -> None:
"""Flush write buffers, waiting until all data is written, internal."""
Expand Down
10 changes: 9 additions & 1 deletion serialx/platforms/serial_pyodide/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ async def _connect( # type: ignore[override]
self._reader_task = self._loop.create_task(self._reader_loop())
self._writer_task = self._loop.create_task(self._writer_loop())

self._protocol.connection_made(self)
self._call_protocol_connection_made()

async def _writer_loop(self) -> None:
while True:
Expand Down Expand Up @@ -308,6 +308,8 @@ async def _set_modem_pins(self, modem_pins: ModemPins) -> None:

def write(self, data: bytes | bytearray | memoryview) -> None:
"""Write data to the transport."""
if self._closing:
return
self._write_buffer_size += len(data)
self._write_queue.put_nowait(bytes(data))

Expand All @@ -323,6 +325,12 @@ def abort(self) -> None:
"""Close the transport immediately, discarding pending writes."""
if self._writer_task is not None and not self._writer_task.done():
self._writer_task.cancel()

while not self._write_queue.empty():
self._write_queue.get_nowait()
self._write_queue.task_done()
self._write_buffer_size = 0

self._cleanup(None)

def __del__(self) -> None:
Expand Down
27 changes: 15 additions & 12 deletions serialx/platforms/serial_rfc2217/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,6 @@ def __init__(
self._rfc2217_waiters: dict[Rfc2217CmdId, asyncio.Future[Rfc2217Command]] = {}
self._tcp_transport: asyncio.Transport | None = None
self._tcp_connection_lost_waiter: asyncio.Future[None] | None = None
self._connection_lost_called = False
self._configured = False

# -- connection lifecycle -----------------------------------------------

Expand Down Expand Up @@ -759,8 +757,7 @@ async def _connect(
await self._negotiate()
await self._configure_port()

self._configured = True
self._protocol.connection_made(self)
self._call_protocol_connection_made()

async def _negotiate(self) -> None:
"""Perform the initial WILL/DO handshake for COM-PORT-OPTION."""
Expand Down Expand Up @@ -882,7 +879,7 @@ def _data_received(self, data: bytes) -> None:
for response in responses:
self._send_command(response)

if serial_data and self._configured:
if serial_data and self._connection_made_called:
self._protocol.data_received(serial_data)

async def _send_and_wait(self, cmd: Rfc2217Command) -> Rfc2217Command:
Expand Down Expand Up @@ -968,23 +965,27 @@ def _tcp_connection_lost(self, exc: Exception | None) -> None:

if self._connection_lost_called:
return
self._connection_lost_called = True
self._closing = True
self._tcp_transport = None

if exc is None:
exc = OSError(errno.EIO, "RFC 2217 connection closed by server")
self._mark_broken(exc)
if not self._user_initiated_close:
if exc is None:
exc = OSError(errno.EIO, "RFC 2217 connection closed by server")
self._mark_broken(exc)

# Fail any pending waiters
# Pending in-protocol waiters can't resolve cleanly mid-handshake, so
# always fail them with *some* exception even on a user-initiated close.
waiter_exc = (
exc if exc is not None else OSError(errno.EIO, "RFC 2217 transport closed")
)
for _expected, telnet_waiter in self._telnet_waiters:
if not telnet_waiter.done():
telnet_waiter.set_exception(exc)
telnet_waiter.set_exception(waiter_exc)
self._telnet_waiters.clear()

for rfc2217_waiter in self._rfc2217_waiters.values():
if not rfc2217_waiter.done():
rfc2217_waiter.set_exception(exc)
rfc2217_waiter.set_exception(waiter_exc)
self._rfc2217_waiters.clear()

if self._serial is not None:
Expand All @@ -1007,6 +1008,7 @@ def close(self) -> None:
if self._connection_lost_called:
return
self._closing = True
self._mark_user_closed()

if self._tcp_transport is not None:
self._tcp_transport.close()
Expand All @@ -1018,6 +1020,7 @@ def abort(self) -> None:
if self._connection_lost_called:
return
self._closing = True
self._mark_user_closed()

if self._tcp_transport is not None:
self._tcp_transport.abort()
Expand Down
14 changes: 8 additions & 6 deletions serialx/platforms/serial_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,6 @@ def __init__(
super().__init__(loop, protocol)
self._tcp_transport: asyncio.Transport | None = None
self._tcp_connection_lost_waiter: asyncio.Future[None] | None = None
self._connection_lost_called = False

async def _connect( # type: ignore[override]
self,
Expand Down Expand Up @@ -254,6 +253,7 @@ async def _connect( # type: ignore[override]
host=self._serial._host,
port=self._serial._port,
)

self._tcp_transport = tcp_transport

if self._connection_lost_called:
Expand All @@ -264,7 +264,7 @@ async def _connect( # type: ignore[override]
self._tcp_transport = None
return

self._protocol.connection_made(self)
self._call_protocol_connection_made()

def _data_received(self, data: bytes) -> None:
"""Handle data received from the TCP transport."""
Expand All @@ -282,12 +282,12 @@ def _connection_lost(self, exc: Exception | None) -> None:
"""Handle connection lost from the TCP transport."""
if self._connection_lost_called:
return
self._connection_lost_called = True
self._closing = True
self._tcp_transport = None
if exc is None:
exc = OSError(errno.EIO, "socket closed by peer")
self._mark_broken(exc)
if not self._user_initiated_close:
if exc is None:
exc = OSError(errno.EIO, "socket closed by peer")
self._mark_broken(exc)
self._call_protocol_connection_lost(exc)

def _tcp_connection_lost(self) -> None:
Expand Down Expand Up @@ -323,6 +323,7 @@ def abort(self) -> None:
if self._connection_lost_called:
return
self._closing = True
self._mark_user_closed()

if self._tcp_transport is not None:
self._tcp_transport.abort()
Expand All @@ -334,6 +335,7 @@ def close(self) -> None:
if self._connection_lost_called:
return
self._closing = True
self._mark_user_closed()

if self._tcp_transport is not None:
self._tcp_transport.close()
Expand Down
Loading
Loading