diff --git a/binance/ws/keepalive_websocket.py b/binance/ws/keepalive_websocket.py index ccc663cc..7a50cacb 100644 --- a/binance/ws/keepalive_websocket.py +++ b/binance/ws/keepalive_websocket.py @@ -31,16 +31,23 @@ def __init__( self._timer = None self._subscription_id = None self._listen_key = None # Used for non spot stream types + self._uses_ws_api_subscription = False # True when using ws_api async def __aexit__(self, *args, **kwargs): - if not self._path: - return if self._timer: self._timer.cancel() self._timer = None # Clean up subscription if it exists if self._subscription_id is not None: + # Unregister the queue from ws_api before unsubscribing + if hasattr(self._client, 'ws_api') and self._client.ws_api: + self._client.ws_api.unregister_subscription_queue(self._subscription_id) await self._unsubscribe_from_user_data_stream() + if self._uses_ws_api_subscription: + # For ws_api subscriptions, we don't manage the connection + return + if not self._path: + return await super().__aexit__(*args, **kwargs) def _build_path(self): @@ -51,16 +58,43 @@ def _build_path(self): async def _before_connect(self): if self._keepalive_type == "user": + # Subscribe via ws_api and register our own queue for events self._subscription_id = await self._subscribe_to_user_data_stream() - # Reuse the ws_api connection that's already established - self.ws = self._client.ws_api.ws - self.ws_state = self._client.ws_api.ws_state - self._queue = self._client.ws_api._queue + self._uses_ws_api_subscription = True + # Register our queue with ws_api so events get routed to us + self._client.ws_api.register_subscription_queue(self._subscription_id, self._queue) + self._path = f"user_subscription:{self._subscription_id}" return if not self._listen_key: self._listen_key = await self._get_listen_key() self._build_path() + async def connect(self): + """Override connect to handle ws_api subscriptions differently.""" + if self._keepalive_type == "user": + # For user sockets using ws_api subscription: + # - Subscribe via ws_api (done in _before_connect) + # - Don't create our own websocket connection + # - Don't start a read loop (ws_api handles reading) + await self._before_connect() + await self._after_connect() + return + # For other keepalive types, use normal connection logic + await super().connect() + + async def recv(self): + """Override recv to work without a read loop for ws_api subscriptions.""" + if self._uses_ws_api_subscription: + # For ws_api subscriptions, just read from queue + res = None + while not res: + try: + res = await asyncio.wait_for(self._queue.get(), timeout=self.TIMEOUT) + except asyncio.TimeoutError: + self._log.debug(f"no message in {self.TIMEOUT} seconds") + return res + return await super().recv() + async def _after_connect(self): if self._timer is None: self._start_socket_timer() diff --git a/binance/ws/reconnecting_websocket.py b/binance/ws/reconnecting_websocket.py index b39cb2a4..36277956 100644 --- a/binance/ws/reconnecting_websocket.py +++ b/binance/ws/reconnecting_websocket.py @@ -15,9 +15,9 @@ pass try: - from websockets.exceptions import ConnectionClosedError # type: ignore + from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore except ImportError: - from websockets import ConnectionClosedError # type: ignore + from websockets import ConnectionClosedError, ConnectionClosedOK # type: ignore Proxy = None @@ -226,6 +226,7 @@ async def _read_loop(self): asyncio.IncompleteReadError, gaierror, ConnectionClosedError, + ConnectionClosedOK, BinanceWebsocketClosed, ) as e: # reports errors and continue loop diff --git a/binance/ws/websocket_api.py b/binance/ws/websocket_api.py index 333d279c..3ef9ed13 100644 --- a/binance/ws/websocket_api.py +++ b/binance/ws/websocket_api.py @@ -14,8 +14,18 @@ def __init__(self, url: str, tld: str = "com", testnet: bool = False, https_prox self._testnet = testnet self._responses: Dict[str, asyncio.Future] = {} self._connection_lock: Optional[asyncio.Lock] = None + # Subscription queues for routing user data stream events + self._subscription_queues: Dict[str, asyncio.Queue] = {} super().__init__(url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy) + def register_subscription_queue(self, subscription_id: str, queue: asyncio.Queue) -> None: + """Register a queue to receive events for a specific subscription.""" + self._subscription_queues[subscription_id] = queue + + def unregister_subscription_queue(self, subscription_id: str) -> None: + """Unregister a subscription queue.""" + self._subscription_queues.pop(subscription_id, None) + @property def connection_lock(self) -> asyncio.Lock: if self._connection_lock is None: @@ -33,7 +43,21 @@ def _handle_message(self, msg): # Check if this is a subscription event (user data stream, etc.) # These have 'subscriptionId' and 'event' fields instead of 'id' if "subscriptionId" in parsed_msg and "event" in parsed_msg: - return parsed_msg["event"] + subscription_id = parsed_msg["subscriptionId"] + event = parsed_msg["event"] + # Route to the registered subscription queue if one exists + if subscription_id in self._subscription_queues: + queue = self._subscription_queues[subscription_id] + try: + queue.put_nowait(event) + except asyncio.QueueFull: + self._log.error(f"Subscription queue full for {subscription_id}, dropping event") + except Exception as e: + self._log.error(f"Error putting event in subscription queue for {subscription_id}: {e}") + return None # Don't put in main queue + else: + # No registered queue, return event for main queue (backward compat) + return event req_id, exception = None, None if "id" in parsed_msg: diff --git a/tests/test_user_socket_integration.py b/tests/test_user_socket_integration.py new file mode 100644 index 00000000..b66a9189 --- /dev/null +++ b/tests/test_user_socket_integration.py @@ -0,0 +1,191 @@ +""" +Integration tests for user socket with ws_api subscription routing. + +These tests verify that the user socket correctly: +1. Uses ws_api for subscription (not creating its own connection) +2. Has its own queue for receiving events (not sharing ws_api's queue) +3. Does not start its own read loop (ws_api handles reading) +4. Properly cleans up subscriptions on exit + +Requirements: +- Binance testnet API credentials (configured in conftest.py) +- Network connectivity to testnet + +Run with: pytest tests/test_user_socket_integration.py -v +""" +import asyncio +import pytest +import pytest_asyncio + +from binance import BinanceSocketManager + + +@pytest_asyncio.fixture +async def socket_manager(clientAsync): + """Create a BinanceSocketManager using the clientAsync fixture from conftest.""" + return BinanceSocketManager(clientAsync) + + +class TestUserSocketArchitecture: + """Tests verifying the user socket architecture is correct.""" + + @pytest.mark.asyncio + async def test_user_socket_has_separate_queue(self, clientAsync, socket_manager): + """User socket should have its own queue, not share ws_api's queue.""" + user_socket = socket_manager.user_socket() + + async with user_socket: + # Queues should be different objects + assert user_socket._queue is not clientAsync.ws_api._queue, \ + "user_socket should have its own queue, not share ws_api's queue" + + @pytest.mark.asyncio + async def test_user_socket_uses_ws_api_subscription(self, clientAsync, socket_manager): + """User socket should use ws_api subscription mechanism.""" + user_socket = socket_manager.user_socket() + + async with user_socket: + # Should be marked as using ws_api subscription + assert user_socket._uses_ws_api_subscription is True, \ + "user_socket should be marked as using ws_api subscription" + + # Should have a subscription ID + assert user_socket._subscription_id is not None, \ + "user_socket should have a subscription ID" + + @pytest.mark.asyncio + async def test_user_socket_no_read_loop(self, clientAsync, socket_manager): + """User socket should NOT have its own read loop (ws_api handles reading).""" + user_socket = socket_manager.user_socket() + + async with user_socket: + # user_socket should not have started its own read loop + assert user_socket._handle_read_loop is None, \ + "user_socket should not have its own read loop" + + # ws_api should have a read loop + assert clientAsync.ws_api._handle_read_loop is not None, \ + "ws_api should have a read loop" + + @pytest.mark.asyncio + async def test_user_socket_queue_registered_with_ws_api(self, clientAsync, socket_manager): + """User socket's queue should be registered with ws_api for event routing.""" + user_socket = socket_manager.user_socket() + + async with user_socket: + sub_id = user_socket._subscription_id + + # Subscription should be registered in ws_api + assert sub_id in clientAsync.ws_api._subscription_queues, \ + "Subscription should be registered with ws_api" + + # Registered queue should be user_socket's queue + registered_queue = clientAsync.ws_api._subscription_queues[sub_id] + assert registered_queue is user_socket._queue, \ + "Registered queue should be user_socket's queue" + + @pytest.mark.asyncio + async def test_user_socket_cleanup_on_exit(self, clientAsync, socket_manager): + """User socket should unregister from ws_api on exit.""" + user_socket = socket_manager.user_socket() + + async with user_socket: + sub_id = user_socket._subscription_id + # Verify it's registered while connected + assert sub_id in clientAsync.ws_api._subscription_queues + + # After exit, subscription should be unregistered + assert sub_id not in clientAsync.ws_api._subscription_queues, \ + "Subscription should be unregistered after exit" + + +class TestUserSocketFunctionality: + """Tests verifying user socket functionality works correctly.""" + + @pytest.mark.asyncio + async def test_user_socket_recv_timeout(self, clientAsync, socket_manager): + """User socket recv() should timeout gracefully when no events.""" + user_socket = socket_manager.user_socket() + + async with user_socket: + # recv() should timeout without errors (no events on quiet account) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(user_socket.recv(), timeout=2) + + @pytest.mark.asyncio + async def test_user_socket_context_manager(self, clientAsync, socket_manager): + """User socket should work as async context manager.""" + user_socket = socket_manager.user_socket() + + # Should not be connected initially + assert user_socket._subscription_id is None + + async with user_socket: + # Should be connected inside context + assert user_socket._subscription_id is not None + assert user_socket._uses_ws_api_subscription is True + + # Subscription ID is cleared after unsubscribe + assert user_socket._subscription_id is None + + +class TestNonUserSockets: + """Tests verifying other socket types still work normally.""" + + @pytest.mark.asyncio + async def test_margin_socket_not_using_ws_api_subscription(self, clientAsync, socket_manager): + """Non-user KeepAliveWebsockets (like margin socket) should not use ws_api subscription.""" + # margin_socket is a KeepAliveWebsocket with keepalive_type="margin" + # Create it but don't connect - just check the flag + margin_socket = socket_manager.margin_socket() + + # Before connecting, the flag should be False (default) + assert margin_socket._uses_ws_api_subscription is False, \ + "Margin socket should not use ws_api subscription" + + # The _keepalive_type should be "margin", not "user" + assert margin_socket._keepalive_type == "margin" + + +class TestWsApiSubscriptionRouting: + """Tests verifying ws_api correctly routes subscription events.""" + + @pytest.mark.asyncio + async def test_ws_api_has_subscription_queues(self, clientAsync): + """ws_api should have subscription queues dict.""" + # Ensure ws_api is initialized + await clientAsync.ws_api._ensure_ws_connection() + + assert hasattr(clientAsync.ws_api, '_subscription_queues'), \ + "ws_api should have _subscription_queues attribute" + assert isinstance(clientAsync.ws_api._subscription_queues, dict), \ + "_subscription_queues should be a dict" + + @pytest.mark.asyncio + async def test_ws_api_register_unregister_queue(self, clientAsync): + """ws_api should be able to register and unregister queues.""" + await clientAsync.ws_api._ensure_ws_connection() + + test_queue = asyncio.Queue() + test_sub_id = "test_subscription_123" + + # Register + clientAsync.ws_api.register_subscription_queue(test_sub_id, test_queue) + assert test_sub_id in clientAsync.ws_api._subscription_queues + assert clientAsync.ws_api._subscription_queues[test_sub_id] is test_queue + + # Unregister + clientAsync.ws_api.unregister_subscription_queue(test_sub_id) + assert test_sub_id not in clientAsync.ws_api._subscription_queues + + @pytest.mark.asyncio + async def test_ws_api_unregister_nonexistent_is_safe(self, clientAsync): + """Unregistering a non-existent subscription should not raise.""" + await clientAsync.ws_api._ensure_ws_connection() + + # Should not raise + clientAsync.ws_api.unregister_subscription_queue("nonexistent_sub_id") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])