diff --git a/overlay/server.py b/overlay/server.py index b6c4784..38579db 100644 --- a/overlay/server.py +++ b/overlay/server.py @@ -45,6 +45,7 @@ def __init__(self, host: str = "0.0.0.0", port: int = 8080) -> None: self.host = host self.port = port self._clients: set[WebSocket] = set() + self._last_state: dict[str, Any] | None = None self._app = Starlette( routes=[ Route("/", self._serve_index), @@ -77,6 +78,12 @@ async def _serve_index(self, request: Request) -> Response: async def _websocket_endpoint(self, websocket: WebSocket) -> None: await websocket.accept() self._clients.add(websocket) + if self._last_state is not None: + try: + await websocket.send_text(json.dumps(self._last_state)) + except Exception: + self._clients.discard(websocket) + return try: while True: await websocket.receive_text() @@ -92,12 +99,16 @@ async def _websocket_endpoint(self, websocket: WebSocket) -> None: async def broadcast(self, state: dict[str, Any]) -> None: """Send *state* as a JSON text frame to every connected WebSocket client. + The state is also cached so that clients connecting after the last + broadcast (e.g. after a page refresh) receive it immediately. + Dead connections are silently removed. Args: state: A JSON-serialisable dict (e.g. from :func:`~overlay.state.serialize_game_state`). """ + self._last_state = state if not self._clients: return message = json.dumps(state) diff --git a/tests/test_overlay.py b/tests/test_overlay.py index 86e0cac..ba5010d 100644 --- a/tests/test_overlay.py +++ b/tests/test_overlay.py @@ -238,3 +238,31 @@ def test_broadcast_to_no_clients_does_not_raise(self): server = OverlayServer() # Should not raise even with zero clients asyncio.run(server.broadcast({"status": "idle"})) + + def test_new_client_receives_last_state_on_connect(self): + """A client that connects after a broadcast immediately gets the cached state.""" + import asyncio + + server = OverlayServer() + client = TestClient(server.app) + + payload = {"status": "running", "attempt_count": 5} + asyncio.run(server.broadcast(payload)) + + with client.websocket_connect("/ws") as ws: + data = ws.receive_text() + + decoded = json.loads(data) + assert decoded["status"] == "running" + assert decoded["attempt_count"] == 5 + + def test_new_client_receives_no_message_when_no_broadcast_yet(self): + """A client that connects before any broadcast does not get a spurious message.""" + server = OverlayServer() + client = TestClient(server.app) + + # Connect and verify that no state is cached, meaning the connection + # handler will not attempt to send an initial message to the client. + # (_last_state is the guard for the send; None means no send occurs.) + with client.websocket_connect("/ws") as ws: + assert server._last_state is None