diff --git a/apps/scan.py b/apps/scan.py index 5761190f..0863e991 100644 --- a/apps/scan.py +++ b/apps/scan.py @@ -22,7 +22,7 @@ import bumble.logging from bumble import data_types from bumble.colors import color -from bumble.device import Advertisement, Device +from bumble.device import Advertisement, Device, DeviceConfiguration from bumble.hci import HCI_LE_1M_PHY, HCI_LE_CODED_PHY, Address, HCI_Constant from bumble.keys import JsonKeyStore from bumble.smp import AddressResolver @@ -144,8 +144,14 @@ async def scan( device_config, hci_source, hci_sink ) else: - device = Device.with_hci( - 'Bumble', 'F0:F1:F2:F3:F4:F5', hci_source, hci_sink + device = Device.from_config_with_hci( + DeviceConfiguration( + name='Bumble', + address=Address('F0:F1:F2:F3:F4:F5'), + keystore='JsonKeyStore', + ), + hci_source, + hci_sink, ) await device.power_on() diff --git a/bumble/device.py b/bumble/device.py index ffaaa7b1..46a5ffcc 100644 --- a/bumble/device.py +++ b/bumble/device.py @@ -3737,284 +3737,168 @@ async def set_connectable(self, connectable: bool = True) -> None: page_scan_enabled=self.connectable, ) - async def connect( + async def connect_le( self, peer_address: hci.Address | str, - transport: core.PhysicalTransport = PhysicalTransport.LE, connection_parameters_preferences: ( dict[hci.Phy, ConnectionParametersPreferences] | None ) = None, own_address_type: hci.OwnAddressType = hci.OwnAddressType.RANDOM, timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT, - always_resolve: bool = False, ) -> Connection: - ''' - Request a connection to a peer. - - When the transport is BLE, this method cannot be called if there is already a - pending connection. - - Args: - peer_address: - hci.Address or name of the device to connect to. - If a string is passed: - If the string is an address followed by a `@` suffix, the `always_resolve` - argument is implicitly set to True, so the connection is made to the - address after resolution. - If the string is any other address, the connection is made to that - address (with or without address resolution, depending on the - `always_resolve` argument). - For any other string, a scan for devices using that string as their name - is initiated, and a connection to the first matching device's address - is made. In that case, `always_resolve` is ignored. - - connection_parameters_preferences: - (BLE only, ignored for BR/EDR) - * None: use the 1M PHY with default parameters - * map: each entry has a PHY as key and a ConnectionParametersPreferences - object as value - - own_address_type: - (BLE only, ignored for BR/EDR) - hci.OwnAddressType.RANDOM to use this device's random address, or - hci.OwnAddressType.PUBLIC to use this device's public address. - - timeout: - Maximum time to wait for a connection to be established, in seconds. - Pass None for an unlimited time. - - always_resolve: - (BLE only, ignored for BR/EDR) - If True, always initiate a scan, resolving addresses, and connect to the - address that resolves to `peer_address`. - ''' - - # Check parameters - if transport not in (PhysicalTransport.LE, PhysicalTransport.BR_EDR): - raise InvalidArgumentError('invalid transport') - transport = core.PhysicalTransport(transport) - - # Adjust the transport automatically if we need to - if transport == PhysicalTransport.LE and not self.le_enabled: - transport = PhysicalTransport.BR_EDR - elif transport == PhysicalTransport.BR_EDR and not self.classic_enabled: - transport = PhysicalTransport.LE - # Check that there isn't already a pending connection - if transport == PhysicalTransport.LE and self.is_le_connecting: + if self.is_le_connecting: raise InvalidStateError('connection already pending') + try_resolve = not self.address_resolution_offload if isinstance(peer_address, str): try: - if transport == PhysicalTransport.LE and peer_address.endswith('@'): - peer_address = hci.Address.from_string_for_transport( - peer_address[:-1], transport - ) - always_resolve = True - logger.debug('forcing address resolution') - else: - peer_address = hci.Address.from_string_for_transport( - peer_address, transport - ) + peer_address = hci.Address.from_string_for_transport( + peer_address, PhysicalTransport.LE + ) except (InvalidArgumentError, ValueError): # If the address is not parsable, assume it is a name instead - always_resolve = False logger.debug('looking for peer by name') assert isinstance(peer_address, str) peer_address = await self.find_peer_by_name( - peer_address, transport + peer_address, PhysicalTransport.LE ) # TODO: timeout - else: - # All BR/EDR addresses should be public addresses - if ( - transport == PhysicalTransport.BR_EDR - and peer_address.address_type != hci.Address.PUBLIC_DEVICE_ADDRESS - ): - raise InvalidArgumentError('BR/EDR addresses must be PUBLIC') + try_resolve = False assert isinstance(peer_address, hci.Address) - if transport == PhysicalTransport.LE and always_resolve: - logger.debug('resolving address') + if ( + try_resolve + and self.address_resolver is not None + and self.address_resolver.can_resolve_to(peer_address) + ): + # If we have an IRK for this address, we should resolve. + logger.debug('have IRK for address, resolving...') peer_address = await self.find_peer_by_identity_address( peer_address ) # TODO: timeout def on_connection(connection): - if transport == PhysicalTransport.LE or ( - # match BR/EDR connection event against peer address - connection.transport == transport - and connection.peer_address == peer_address - ): - pending_connection.set_result(connection) + pending_connection.set_result(connection) def on_connection_failure(error: core.ConnectionError): - if transport == PhysicalTransport.LE or ( - # match BR/EDR connection failure event against peer address - error.transport == transport - and error.peer_address == peer_address - ): - pending_connection.set_exception(error) + pending_connection.set_exception(error) - # Create a future so that we can wait for the connection's result + # Create a future so that we can wait for the connection result pending_connection = asyncio.get_running_loop().create_future() self.on(self.EVENT_CONNECTION, on_connection) self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure) try: # Tell the controller to connect - if transport == PhysicalTransport.LE: - if connection_parameters_preferences is None: - if connection_parameters_preferences is None: - connection_parameters_preferences = { - hci.HCI_LE_1M_PHY: ConnectionParametersPreferences.default - } + if connection_parameters_preferences is None: + connection_parameters_preferences = { + hci.HCI_LE_1M_PHY: ConnectionParametersPreferences.default + } - self.connect_own_address_type = own_address_type + self.connect_own_address_type = own_address_type - if self.host.supports_command( - hci.HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND - ): - # Only keep supported PHYs - phys = sorted( - list( - set( - filter( - self.supports_le_phy, - connection_parameters_preferences.keys(), - ) + if self.host.supports_command( + hci.HCI_LE_EXTENDED_CREATE_CONNECTION_COMMAND + ): + # Only keep supported PHYs + phys = sorted( + list( + set( + filter( + self.supports_le_phy, + connection_parameters_preferences.keys(), ) ) ) - if not phys: - raise InvalidArgumentError('at least one supported PHY needed') - - phy_count = len(phys) - initiating_phys = hci.phy_list_to_bits(phys) - - connection_interval_mins = [ - int( - connection_parameters_preferences[ - phy - ].connection_interval_min - / 1.25 - ) - for phy in phys - ] - connection_interval_maxs = [ - int( - connection_parameters_preferences[ - phy - ].connection_interval_max - / 1.25 - ) - for phy in phys - ] - max_latencies = [ - connection_parameters_preferences[phy].max_latency - for phy in phys - ] - supervision_timeouts = [ - int( - connection_parameters_preferences[phy].supervision_timeout - / 10 - ) - for phy in phys - ] - min_ce_lengths = [ - int( - connection_parameters_preferences[phy].min_ce_length / 0.625 - ) - for phy in phys - ] - max_ce_lengths = [ - int( - connection_parameters_preferences[phy].max_ce_length / 0.625 - ) - for phy in phys - ] + ) + if not phys: + raise InvalidArgumentError('at least one supported PHY needed') - await self.send_async_command( - hci.HCI_LE_Extended_Create_Connection_Command( - initiator_filter_policy=0, - own_address_type=own_address_type, - peer_address_type=peer_address.address_type, - peer_address=peer_address, - initiating_phys=initiating_phys, - scan_intervals=( - int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625), - ) - * phy_count, - scan_windows=( - int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625), - ) - * phy_count, - connection_interval_mins=connection_interval_mins, - connection_interval_maxs=connection_interval_maxs, - max_latencies=max_latencies, - supervision_timeouts=supervision_timeouts, - min_ce_lengths=min_ce_lengths, - max_ce_lengths=max_ce_lengths, - ) + phy_count = len(phys) + initiating_phys = hci.phy_list_to_bits(phys) + + connection_interval_mins = [ + int( + connection_parameters_preferences[phy].connection_interval_min + / 1.25 ) - else: - if hci.HCI_LE_1M_PHY not in connection_parameters_preferences: - raise InvalidArgumentError('1M PHY preferences required') - - prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY] - await self.send_async_command( - hci.HCI_LE_Create_Connection_Command( - le_scan_interval=int( - DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625 - ), - le_scan_window=int( - DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625 - ), - initiator_filter_policy=0, - peer_address_type=peer_address.address_type, - peer_address=peer_address, - own_address_type=own_address_type, - connection_interval_min=int( - prefs.connection_interval_min / 1.25 - ), - connection_interval_max=int( - prefs.connection_interval_max / 1.25 - ), - max_latency=prefs.max_latency, - supervision_timeout=int(prefs.supervision_timeout / 10), - min_ce_length=int(prefs.min_ce_length / 0.625), - max_ce_length=int(prefs.max_ce_length / 0.625), + for phy in phys + ] + connection_interval_maxs = [ + int( + connection_parameters_preferences[phy].connection_interval_max + / 1.25 + ) + for phy in phys + ] + max_latencies = [ + connection_parameters_preferences[phy].max_latency for phy in phys + ] + supervision_timeouts = [ + int(connection_parameters_preferences[phy].supervision_timeout / 10) + for phy in phys + ] + min_ce_lengths = [ + int(connection_parameters_preferences[phy].min_ce_length / 0.625) + for phy in phys + ] + max_ce_lengths = [ + int(connection_parameters_preferences[phy].max_ce_length / 0.625) + for phy in phys + ] + + await self.send_async_command( + hci.HCI_LE_Extended_Create_Connection_Command( + initiator_filter_policy=0, + own_address_type=own_address_type, + peer_address_type=peer_address.address_type, + peer_address=peer_address, + initiating_phys=initiating_phys, + scan_intervals=( + int(DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625), ) + * phy_count, + scan_windows=(int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625),) + * phy_count, + connection_interval_mins=connection_interval_mins, + connection_interval_maxs=connection_interval_maxs, + max_latencies=max_latencies, + supervision_timeouts=supervision_timeouts, + min_ce_lengths=min_ce_lengths, + max_ce_lengths=max_ce_lengths, ) - else: - # Save pending connection - self.pending_connections[peer_address] = Connection( - device=self, - handle=0, - transport=core.PhysicalTransport.BR_EDR, - self_address=self.public_address, - self_resolvable_address=None, - peer_address=peer_address, - peer_resolvable_address=None, - role=hci.Role.CENTRAL, - parameters=Connection.Parameters(0, 0, 0), ) + else: + if hci.HCI_LE_1M_PHY not in connection_parameters_preferences: + raise InvalidArgumentError('1M PHY preferences required') - # TODO: allow passing other settings + prefs = connection_parameters_preferences[hci.HCI_LE_1M_PHY] await self.send_async_command( - hci.HCI_Create_Connection_Command( - bd_addr=peer_address, - packet_type=0xCC18, # FIXME: change - page_scan_repetition_mode=hci.HCI_R2_PAGE_SCAN_REPETITION_MODE, - clock_offset=0x0000, - allow_role_switch=0x01, - reserved=0, + hci.HCI_LE_Create_Connection_Command( + le_scan_interval=int( + DEVICE_DEFAULT_CONNECT_SCAN_INTERVAL / 0.625 + ), + le_scan_window=int(DEVICE_DEFAULT_CONNECT_SCAN_WINDOW / 0.625), + initiator_filter_policy=0, + peer_address_type=peer_address.address_type, + peer_address=peer_address, + own_address_type=own_address_type, + connection_interval_min=int( + prefs.connection_interval_min / 1.25 + ), + connection_interval_max=int( + prefs.connection_interval_max / 1.25 + ), + max_latency=prefs.max_latency, + supervision_timeout=int(prefs.supervision_timeout / 10), + min_ce_length=int(prefs.min_ce_length / 0.625), + max_ce_length=int(prefs.max_ce_length / 0.625), ) ) # Wait for the connection process to complete - if transport == PhysicalTransport.LE: - self.le_connecting = True + self.le_connecting = True if timeout is None: return await utils.cancel_on_event( @@ -4026,14 +3910,107 @@ def on_connection_failure(error: core.ConnectionError): asyncio.shield(pending_connection), timeout ) except asyncio.TimeoutError: - if transport == PhysicalTransport.LE: - await self.send_sync_command( - hci.HCI_LE_Create_Connection_Cancel_Command() - ) - else: - await self.send_sync_command( - hci.HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) + await self.send_sync_command( + hci.HCI_LE_Create_Connection_Cancel_Command() + ) + + try: + return await utils.cancel_on_event( + self, Device.EVENT_FLUSH, pending_connection ) + except core.ConnectionError as error: + raise core.TimeoutError() from error + finally: + self.remove_listener(self.EVENT_CONNECTION, on_connection) + self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure) + self.le_connecting = False + self.connect_own_address_type = None + + async def connect_classic( + self, + peer_address: hci.Address | str, + timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT, + ) -> Connection: + if isinstance(peer_address, str): + try: + peer_address = hci.Address.from_string_for_transport( + peer_address, PhysicalTransport.BR_EDR + ) + except (InvalidArgumentError, ValueError): + # If the address is not parsable, assume it is a name instead + logger.debug('looking for peer by name') + assert isinstance(peer_address, str) + peer_address = await self.find_peer_by_name( + peer_address, PhysicalTransport.BR_EDR + ) # TODO: timeout + else: + # All BR/EDR addresses should be public addresses + if peer_address.address_type != hci.Address.PUBLIC_DEVICE_ADDRESS: + raise InvalidArgumentError('BR/EDR addresses must be PUBLIC') + + assert isinstance(peer_address, hci.Address) + + def on_connection(connection): + if ( + # match BR/EDR connection event against peer address + connection.transport == PhysicalTransport.BR_EDR + and connection.peer_address == peer_address + ): + pending_connection.set_result(connection) + + def on_connection_failure(error: core.ConnectionError): + if ( + # match BR/EDR connection failure event against peer address + error.transport == PhysicalTransport.BR_EDR + and error.peer_address == peer_address + ): + pending_connection.set_exception(error) + + # Create a future so that we can wait for the connection result + pending_connection = asyncio.get_running_loop().create_future() + self.on(self.EVENT_CONNECTION, on_connection) + self.on(self.EVENT_CONNECTION_FAILURE, on_connection_failure) + + try: + # Save pending connection + self.pending_connections[peer_address] = Connection( + device=self, + handle=0, + transport=core.PhysicalTransport.BR_EDR, + self_address=self.public_address, + self_resolvable_address=None, + peer_address=peer_address, + peer_resolvable_address=None, + role=hci.Role.CENTRAL, + parameters=Connection.Parameters(0, 0, 0), + ) + + # TODO: allow passing other settings + await self.send_async_command( + hci.HCI_Create_Connection_Command( + bd_addr=peer_address, + packet_type=0xCC18, # FIXME: change + page_scan_repetition_mode=hci.HCI_R2_PAGE_SCAN_REPETITION_MODE, + clock_offset=0x0000, + allow_role_switch=0x01, + reserved=0, + ) + ) + + # Wait for the connection process to complete + if timeout is None: + return await utils.cancel_on_event( + self, Device.EVENT_FLUSH, pending_connection + ) + + try: + return await asyncio.wait_for( + asyncio.shield(pending_connection), timeout + ) + except asyncio.TimeoutError: + await self.send_sync_command( + hci.HCI_Create_Connection_Cancel_Command(bd_addr=peer_address) + ) try: return await utils.cancel_on_event( @@ -4044,11 +4021,78 @@ def on_connection_failure(error: core.ConnectionError): finally: self.remove_listener(self.EVENT_CONNECTION, on_connection) self.remove_listener(self.EVENT_CONNECTION_FAILURE, on_connection_failure) - if transport == PhysicalTransport.LE: - self.le_connecting = False - self.connect_own_address_type = None - else: - self.pending_connections.pop(peer_address, None) + self.pending_connections.pop(peer_address, None) + + async def connect( + self, + peer_address: hci.Address | str, + transport: core.PhysicalTransport = PhysicalTransport.LE, + connection_parameters_preferences: ( + dict[hci.Phy, ConnectionParametersPreferences] | None + ) = None, + own_address_type: hci.OwnAddressType = hci.OwnAddressType.RANDOM, + timeout: float | None = DEVICE_DEFAULT_CONNECT_TIMEOUT, + always_resolve: bool = False, + ) -> Connection: + ''' + Request a connection to a peer. + + When the transport is BLE, this method cannot be called if there is already a + pending connection. + + Args: + peer_address: + hci.Address or name of the device to connect to. + If a string is passed: + [deprecated] If the string is an address followed by a `@` suffix, the + `always_resolve`argument is implicitly set to True, so the connection is + made to the address after resolution. + If the string is any other address, the connection is made to that + address (with or without address resolution, depending on the + `always_resolve` argument). + For any other string, a scan for devices using that string as their name + is initiated, and a connection to the first matching device's address + is made. In that case, `always_resolve` is ignored. + + connection_parameters_preferences: + (BLE only, ignored for BR/EDR) + * None: use the 1M PHY with default parameters + * map: each entry has a PHY as key and a ConnectionParametersPreferences + object as value + + own_address_type: + (BLE only, ignored for BR/EDR) + hci.OwnAddressType.RANDOM to use this device's random address, or + hci.OwnAddressType.PUBLIC to use this device's public address. + + timeout: + Maximum time to wait for a connection to be established, in seconds. + Pass None for an unlimited time. + + always_resolve: + [deprecated] (ignore) + ''' + + # Connect using the appropriate transport + # (auto-correct the transport based on declared capabilities) + if transport == PhysicalTransport.LE or ( + self.le_enabled and not self.classic_enabled + ): + return await self.connect_le( + peer_address=peer_address, + connection_parameters_preferences=connection_parameters_preferences, + own_address_type=own_address_type, + timeout=timeout, + ) + + if transport == PhysicalTransport.BR_EDR or ( + self.classic_enabled and not self.le_enabled + ): + return await self.connect_classic( + peer_address=peer_address, timeout=timeout + ) + + raise ValueError('invalid transport') async def accept( self, @@ -4695,6 +4739,8 @@ async def find_peer_by_identity_address( Scan for a peer with a resolvable address that can be resolved to a given identity address. """ + if self.address_resolver is None: + raise InvalidStateError('no resolver') # Create a future to wait for an address to be found peer_address = asyncio.get_running_loop().create_future() diff --git a/bumble/host.py b/bumble/host.py index ed8ec7eb..57d8e5d2 100644 --- a/bumble/host.py +++ b/bumble/host.py @@ -803,7 +803,9 @@ def send_acl_sdu(self, connection_handle: int, sdu: bytes) -> None: data=pdu, ) logger.debug( - '>>> ACL packet enqueue: (Handle=0x%04X) %s', connection_handle, pdu + '>>> ACL packet enqueue: (handle=0x%04X) %s', + connection_handle, + pdu.hex(), ) packet_queue.enqueue(acl_packet, connection_handle) diff --git a/bumble/smp.py b/bumble/smp.py index c27e9e7d..76b4d00b 100644 --- a/bumble/smp.py +++ b/bumble/smp.py @@ -27,7 +27,7 @@ import asyncio import enum import logging -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, ClassVar, TypeVar, cast @@ -507,10 +507,15 @@ def smp_auth_req(bonding: bool, mitm: bool, sc: bool, keypress: bool, ct2: bool) # ----------------------------------------------------------------------------- class AddressResolver: - def __init__(self, resolving_keys): + def __init__(self, resolving_keys: Sequence[tuple[bytes, Address]]) -> None: self.resolving_keys = resolving_keys - def resolve(self, address): + def can_resolve_to(self, address: Address) -> bool: + return any( + resolved_address == address for _, resolved_address in self.resolving_keys + ) + + def resolve(self, address: Address) -> Address | None: address_bytes = bytes(address) hash_part = address_bytes[0:3] prand = address_bytes[3:6]