Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions binance/ws/keepalive_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,16 +31,23 @@ def __init__(
self._timer = None
self._subscription_id = None
self._listen_key = None # Used for non spot stream types
self._uses_ws_api_subscription = False # True when using ws_api

async def __aexit__(self, *args, **kwargs):
if not self._path:
return
if self._timer:
self._timer.cancel()
self._timer = None
# Clean up subscription if it exists
if self._subscription_id is not None:
# Unregister the queue from ws_api before unsubscribing
if hasattr(self._client, 'ws_api') and self._client.ws_api:
self._client.ws_api.unregister_subscription_queue(self._subscription_id)
await self._unsubscribe_from_user_data_stream()
if self._uses_ws_api_subscription:
# For ws_api subscriptions, we don't manage the connection
return
if not self._path:
return
await super().__aexit__(*args, **kwargs)

def _build_path(self):
Expand All @@ -51,16 +58,43 @@ def _build_path(self):

async def _before_connect(self):
if self._keepalive_type == "user":
# Subscribe via ws_api and register our own queue for events
self._subscription_id = await self._subscribe_to_user_data_stream()
# Reuse the ws_api connection that's already established
self.ws = self._client.ws_api.ws
self.ws_state = self._client.ws_api.ws_state
self._queue = self._client.ws_api._queue
self._uses_ws_api_subscription = True
# Register our queue with ws_api so events get routed to us
self._client.ws_api.register_subscription_queue(self._subscription_id, self._queue)
self._path = f"user_subscription:{self._subscription_id}"
return
if not self._listen_key:
self._listen_key = await self._get_listen_key()
self._build_path()

async def connect(self):
"""Override connect to handle ws_api subscriptions differently."""
if self._keepalive_type == "user":
# For user sockets using ws_api subscription:
# - Subscribe via ws_api (done in _before_connect)
# - Don't create our own websocket connection
# - Don't start a read loop (ws_api handles reading)
await self._before_connect()
await self._after_connect()
return
# For other keepalive types, use normal connection logic
await super().connect()

async def recv(self):
"""Override recv to work without a read loop for ws_api subscriptions."""
if self._uses_ws_api_subscription:
# For ws_api subscriptions, just read from queue
res = None
while not res:
try:
res = await asyncio.wait_for(self._queue.get(), timeout=self.TIMEOUT)
except asyncio.TimeoutError:
self._log.debug(f"no message in {self.TIMEOUT} seconds")
return res
return await super().recv()

async def _after_connect(self):
if self._timer is None:
self._start_socket_timer()
Expand Down
5 changes: 3 additions & 2 deletions binance/ws/reconnecting_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
pass

try:
from websockets.exceptions import ConnectionClosedError # type: ignore
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK # type: ignore
except ImportError:
from websockets import ConnectionClosedError # type: ignore
from websockets import ConnectionClosedError, ConnectionClosedOK # type: ignore


Proxy = None
Expand Down Expand Up @@ -226,6 +226,7 @@ async def _read_loop(self):
asyncio.IncompleteReadError,
gaierror,
ConnectionClosedError,
ConnectionClosedOK,
BinanceWebsocketClosed,
) as e:
# reports errors and continue loop
Expand Down
26 changes: 25 additions & 1 deletion binance/ws/websocket_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,18 @@ def __init__(self, url: str, tld: str = "com", testnet: bool = False, https_prox
self._testnet = testnet
self._responses: Dict[str, asyncio.Future] = {}
self._connection_lock: Optional[asyncio.Lock] = None
# Subscription queues for routing user data stream events
self._subscription_queues: Dict[str, asyncio.Queue] = {}
super().__init__(url=url, prefix="", path="", is_binary=False, https_proxy=https_proxy)

def register_subscription_queue(self, subscription_id: str, queue: asyncio.Queue) -> None:
"""Register a queue to receive events for a specific subscription."""
self._subscription_queues[subscription_id] = queue

def unregister_subscription_queue(self, subscription_id: str) -> None:
"""Unregister a subscription queue."""
self._subscription_queues.pop(subscription_id, None)

@property
def connection_lock(self) -> asyncio.Lock:
if self._connection_lock is None:
Expand All @@ -33,7 +43,21 @@ def _handle_message(self, msg):
# Check if this is a subscription event (user data stream, etc.)
# These have 'subscriptionId' and 'event' fields instead of 'id'
if "subscriptionId" in parsed_msg and "event" in parsed_msg:
return parsed_msg["event"]
subscription_id = parsed_msg["subscriptionId"]
event = parsed_msg["event"]
# Route to the registered subscription queue if one exists
if subscription_id in self._subscription_queues:
queue = self._subscription_queues[subscription_id]
try:
queue.put_nowait(event)
except asyncio.QueueFull:
self._log.error(f"Subscription queue full for {subscription_id}, dropping event")
except Exception as e:
self._log.error(f"Error putting event in subscription queue for {subscription_id}: {e}")
return None # Don't put in main queue
else:
# No registered queue, return event for main queue (backward compat)
return event

req_id, exception = None, None
if "id" in parsed_msg:
Expand Down
191 changes: 191 additions & 0 deletions tests/test_user_socket_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""
Integration tests for user socket with ws_api subscription routing.

These tests verify that the user socket correctly:
1. Uses ws_api for subscription (not creating its own connection)
2. Has its own queue for receiving events (not sharing ws_api's queue)
3. Does not start its own read loop (ws_api handles reading)
4. Properly cleans up subscriptions on exit

Requirements:
- Binance testnet API credentials (configured in conftest.py)
- Network connectivity to testnet

Run with: pytest tests/test_user_socket_integration.py -v
"""
import asyncio
import pytest
import pytest_asyncio

from binance import BinanceSocketManager


@pytest_asyncio.fixture
async def socket_manager(clientAsync):
"""Create a BinanceSocketManager using the clientAsync fixture from conftest."""
return BinanceSocketManager(clientAsync)


class TestUserSocketArchitecture:
"""Tests verifying the user socket architecture is correct."""

@pytest.mark.asyncio
async def test_user_socket_has_separate_queue(self, clientAsync, socket_manager):
"""User socket should have its own queue, not share ws_api's queue."""
user_socket = socket_manager.user_socket()

async with user_socket:
# Queues should be different objects
assert user_socket._queue is not clientAsync.ws_api._queue, \
"user_socket should have its own queue, not share ws_api's queue"

@pytest.mark.asyncio
async def test_user_socket_uses_ws_api_subscription(self, clientAsync, socket_manager):
"""User socket should use ws_api subscription mechanism."""
user_socket = socket_manager.user_socket()

async with user_socket:
# Should be marked as using ws_api subscription
assert user_socket._uses_ws_api_subscription is True, \
"user_socket should be marked as using ws_api subscription"

# Should have a subscription ID
assert user_socket._subscription_id is not None, \
"user_socket should have a subscription ID"

@pytest.mark.asyncio
async def test_user_socket_no_read_loop(self, clientAsync, socket_manager):
"""User socket should NOT have its own read loop (ws_api handles reading)."""
user_socket = socket_manager.user_socket()

async with user_socket:
# user_socket should not have started its own read loop
assert user_socket._handle_read_loop is None, \
"user_socket should not have its own read loop"

# ws_api should have a read loop
assert clientAsync.ws_api._handle_read_loop is not None, \
"ws_api should have a read loop"

@pytest.mark.asyncio
async def test_user_socket_queue_registered_with_ws_api(self, clientAsync, socket_manager):
"""User socket's queue should be registered with ws_api for event routing."""
user_socket = socket_manager.user_socket()

async with user_socket:
sub_id = user_socket._subscription_id

# Subscription should be registered in ws_api
assert sub_id in clientAsync.ws_api._subscription_queues, \
"Subscription should be registered with ws_api"

# Registered queue should be user_socket's queue
registered_queue = clientAsync.ws_api._subscription_queues[sub_id]
assert registered_queue is user_socket._queue, \
"Registered queue should be user_socket's queue"

@pytest.mark.asyncio
async def test_user_socket_cleanup_on_exit(self, clientAsync, socket_manager):
"""User socket should unregister from ws_api on exit."""
user_socket = socket_manager.user_socket()

async with user_socket:
sub_id = user_socket._subscription_id
# Verify it's registered while connected
assert sub_id in clientAsync.ws_api._subscription_queues

# After exit, subscription should be unregistered
assert sub_id not in clientAsync.ws_api._subscription_queues, \
"Subscription should be unregistered after exit"


class TestUserSocketFunctionality:
"""Tests verifying user socket functionality works correctly."""

@pytest.mark.asyncio
async def test_user_socket_recv_timeout(self, clientAsync, socket_manager):
"""User socket recv() should timeout gracefully when no events."""
user_socket = socket_manager.user_socket()

async with user_socket:
# recv() should timeout without errors (no events on quiet account)
with pytest.raises(asyncio.TimeoutError):
await asyncio.wait_for(user_socket.recv(), timeout=2)

@pytest.mark.asyncio
async def test_user_socket_context_manager(self, clientAsync, socket_manager):
"""User socket should work as async context manager."""
user_socket = socket_manager.user_socket()

# Should not be connected initially
assert user_socket._subscription_id is None

async with user_socket:
# Should be connected inside context
assert user_socket._subscription_id is not None
assert user_socket._uses_ws_api_subscription is True

# Subscription ID is cleared after unsubscribe
assert user_socket._subscription_id is None


class TestNonUserSockets:
"""Tests verifying other socket types still work normally."""

@pytest.mark.asyncio
async def test_margin_socket_not_using_ws_api_subscription(self, clientAsync, socket_manager):
"""Non-user KeepAliveWebsockets (like margin socket) should not use ws_api subscription."""
# margin_socket is a KeepAliveWebsocket with keepalive_type="margin"
# Create it but don't connect - just check the flag
margin_socket = socket_manager.margin_socket()

# Before connecting, the flag should be False (default)
assert margin_socket._uses_ws_api_subscription is False, \
"Margin socket should not use ws_api subscription"

# The _keepalive_type should be "margin", not "user"
assert margin_socket._keepalive_type == "margin"


class TestWsApiSubscriptionRouting:
"""Tests verifying ws_api correctly routes subscription events."""

@pytest.mark.asyncio
async def test_ws_api_has_subscription_queues(self, clientAsync):
"""ws_api should have subscription queues dict."""
# Ensure ws_api is initialized
await clientAsync.ws_api._ensure_ws_connection()

assert hasattr(clientAsync.ws_api, '_subscription_queues'), \
"ws_api should have _subscription_queues attribute"
assert isinstance(clientAsync.ws_api._subscription_queues, dict), \
"_subscription_queues should be a dict"

@pytest.mark.asyncio
async def test_ws_api_register_unregister_queue(self, clientAsync):
"""ws_api should be able to register and unregister queues."""
await clientAsync.ws_api._ensure_ws_connection()

test_queue = asyncio.Queue()
test_sub_id = "test_subscription_123"

# Register
clientAsync.ws_api.register_subscription_queue(test_sub_id, test_queue)
assert test_sub_id in clientAsync.ws_api._subscription_queues
assert clientAsync.ws_api._subscription_queues[test_sub_id] is test_queue

# Unregister
clientAsync.ws_api.unregister_subscription_queue(test_sub_id)
assert test_sub_id not in clientAsync.ws_api._subscription_queues

@pytest.mark.asyncio
async def test_ws_api_unregister_nonexistent_is_safe(self, clientAsync):
"""Unregistering a non-existent subscription should not raise."""
await clientAsync.ws_api._ensure_ws_connection()

# Should not raise
clientAsync.ws_api.unregister_subscription_queue("nonexistent_sub_id")


if __name__ == "__main__":
pytest.main([__file__, "-v"])
Loading