diff --git a/src/publisher/core.py b/src/publisher/core.py index 5fa8295..6e7e926 100644 --- a/src/publisher/core.py +++ b/src/publisher/core.py @@ -1,17 +1,38 @@ from __future__ import annotations from abc import ABC, abstractmethod +from datetime import datetime import json import re from typing import TYPE_CHECKING, Any, TypeVar import mqtt_topics +from utils import datetime_to_str if TYPE_CHECKING: from configuration import Configuration T = TypeVar("T") +type Publishable = bool | int | float | str | dict[str, Any] | datetime +"""Closed union of value types this gateway knows how to publish to MQTT. + +Mirrors the typed `publish_*` methods on :class:`Publisher` plus the `dict` +shape handled by `publish_json`, and `datetime`, which is stringified via +:func:`utils.datetime_to_str`. Use it at signature boundaries when a caller +holds "something publishable" without statically knowing which arm. +""" + +type WirePayload = bool | int | float | str +"""Primitive subset of :data:`Publishable` that reaches the transport layer. + +After the typed `publish_*` methods do their work (`publish_json` serializes +dicts to JSON strings, `publish_datetime` stringifies via +:func:`utils.datetime_to_str`), only these scalar arms cross the +publisher/transport boundary. Use `WirePayload | None` for wire-level helpers +where `None` means "clear the retained message." +""" + class MqttCommandListener(ABC): @abstractmethod @@ -108,6 +129,49 @@ def publish_float( ) -> None: raise NotImplementedError + def publish_datetime( + self, + key: str, + value: datetime, + no_prefix: bool = False, + *, + retain: bool = True, + ) -> None: + """Stringify a datetime via :func:`utils.datetime_to_str` and publish.""" + self.publish_str(key, datetime_to_str(value), no_prefix, retain=retain) + + def publish( + self, + key: str, + value: Publishable, + no_prefix: bool = False, + *, + retain: bool = True, + ) -> None: + """Dispatch to the appropriate typed publish_* based on value type. + + For callers that hold a `Publishable` without statically knowing + which arm of the union it is. `retain` is forwarded to every arm. + """ + # bool must precede int: isinstance(True, int) is True in Python. + if isinstance(value, bool): + self.publish_bool(key, value, no_prefix, retain=retain) + elif isinstance(value, int): + self.publish_int(key, value, no_prefix, retain=retain) + elif isinstance(value, float): + self.publish_float(key, value, no_prefix, retain=retain) + elif isinstance(value, str): + self.publish_str(key, value, no_prefix, retain=retain) + elif isinstance(value, dict): + self.publish_json(key, value, no_prefix, retain=retain) + elif isinstance(value, datetime): + self.publish_datetime(key, value, no_prefix, retain=retain) + else: + # Defensive: type system rules this out, but `Any` callers can sneak + # an unsupported runtime type through; raise rather than silently no-op. + msg = f"Unsupported value type: {type(value).__name__}" # type: ignore[unreachable] + raise TypeError(msg) + @abstractmethod def clear_topic(self, key: str, no_prefix: bool = False) -> None: raise NotImplementedError diff --git a/src/publisher/log_publisher.py b/src/publisher/log_publisher.py index 45e30d2..7969c61 100644 --- a/src/publisher/log_publisher.py +++ b/src/publisher/log_publisher.py @@ -3,7 +3,7 @@ import logging from typing import Any, override -from publisher.core import Publisher +from publisher.core import Publisher, WirePayload LOG = logging.getLogger(__name__) LOG.setLevel(level="DEBUG") @@ -62,5 +62,7 @@ def publish_float( def clear_topic(self, key: str, no_prefix: bool = False) -> None: self.internal_publish(key, None) - def internal_publish(self, key: str, value: Any, *, retain: bool = True) -> None: + def internal_publish( + self, key: str, value: WirePayload | None, *, retain: bool = True + ) -> None: LOG.debug(f"{key}: {value} (retain={retain})") diff --git a/src/publisher/mqtt_publisher.py b/src/publisher/mqtt_publisher.py index 6d3256f..ed535d4 100644 --- a/src/publisher/mqtt_publisher.py +++ b/src/publisher/mqtt_publisher.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: from configuration import Configuration from integrations.openwb.charging_station import ChargingStation + from publisher.core import WirePayload LOG = logging.getLogger(__name__) @@ -226,7 +227,9 @@ async def __handle_imported_energy(self, topic: str, payload: str) -> None: vin, imported_energy_wh ) - def __publish(self, topic: str, payload: Any, *, retain: bool = True) -> None: + def __publish( + self, topic: str, payload: WirePayload | None, *, retain: bool = True + ) -> None: self.client.publish(topic, payload, retain=retain) @override diff --git a/src/status_publisher/__init__.py b/src/status_publisher/__init__.py index 4fffbba..387f13f 100644 --- a/src/status_publisher/__init__.py +++ b/src/status_publisher/__init__.py @@ -1,10 +1,9 @@ from __future__ import annotations from abc import ABCMeta, abstractmethod -from datetime import datetime -from typing import TYPE_CHECKING, Any, Final, TypeVar +from typing import TYPE_CHECKING, Final -from utils import datetime_to_str +from publisher.core import Publishable if TYPE_CHECKING: from collections.abc import Callable @@ -12,9 +11,6 @@ from publisher.core import Publisher from vehicle_info import VehicleInfo -T = TypeVar("T") -Publishable = TypeVar("Publishable", str, int, float, bool, dict[str, Any], datetime) - class VehicleDataPublisher[I, O](metaclass=ABCMeta): def __init__( @@ -28,65 +24,37 @@ def __init__( def publish(self, data: I) -> O: raise NotImplementedError - def _publish( + def _publish[V: Publishable]( self, *, topic: str, - value: Publishable | None, - validator: Callable[[Publishable], bool] = lambda _: True, + value: V | None, + validator: Callable[[V], bool] = lambda _: True, no_prefix: bool = False, retain: bool = True, - ) -> tuple[bool, Publishable | None]: + ) -> tuple[bool, V | None]: if value is None or not validator(value): return False, None actual_topic = topic if no_prefix else self.__get_topic(topic) - published = self._publish_directly( - topic=actual_topic, value=value, retain=retain - ) - return published, value + self.__publisher.publish(actual_topic, value, retain=retain) + return True, value - def _transform_and_publish( + def _transform_and_publish[T, V: Publishable]( self, *, topic: str, value: T | None, validator: Callable[[T], bool] = lambda _: True, - transform: Callable[[T], Publishable], + transform: Callable[[T], V], no_prefix: bool = False, retain: bool = True, - ) -> tuple[bool, Publishable | None]: + ) -> tuple[bool, V | None]: if value is None or not validator(value): return False, None actual_topic = topic if no_prefix else self.__get_topic(topic) transformed_value = transform(value) - published = self._publish_directly( - topic=actual_topic, value=transformed_value, retain=retain - ) - return published, transformed_value - - def _publish_directly( - self, *, topic: str, value: Publishable, retain: bool = True - ) -> bool: - published = False - if isinstance(value, bool): - self.__publisher.publish_bool(topic, value) - published = True - elif isinstance(value, int): - self.__publisher.publish_int(topic, value) - published = True - elif isinstance(value, float): - self.__publisher.publish_float(topic, value) - published = True - elif isinstance(value, str): - self.__publisher.publish_str(topic, value) - published = True - elif isinstance(value, dict): - self.__publisher.publish_json(topic, value, retain=retain) - published = True - elif isinstance(value, datetime): - self.__publisher.publish_str(topic, datetime_to_str(value)) - published = True - return published + self.__publisher.publish(actual_topic, transformed_value, retain=retain) + return True, transformed_value def __get_topic(self, sub_topic: str) -> str: return f"{self.__mqtt_vehicle_prefix}/{sub_topic}" diff --git a/src/vehicle.py b/src/vehicle.py index d8f7b7a..731293d 100644 --- a/src/vehicle.py +++ b/src/vehicle.py @@ -4,7 +4,7 @@ from enum import Enum, unique import logging import math -from typing import TYPE_CHECKING, Any, Final, TypeVar +from typing import TYPE_CHECKING, Final from apscheduler.triggers.cron import CronTrigger from saic_ismart_client_ng.api.vehicle_charging import ( @@ -17,6 +17,7 @@ from extractors import extract_electric_range, extract_soc import mqtt_topics +from publisher.core import Publishable from status_publisher.charge.chrg_mgmt_data_resp import ( ChrgMgmtDataRespProcessingResult, ChrgMgmtDataRespPublisher, @@ -26,7 +27,6 @@ VehicleStatusRespProcessingResult, VehicleStatusRespPublisher, ) -from utils import datetime_to_str if TYPE_CHECKING: from collections.abc import Callable @@ -42,11 +42,6 @@ from publisher.core import Publisher from vehicle_info import VehicleInfo - T = TypeVar("T") - Publishable = TypeVar( - "Publishable", str, int, float, bool, dict[str, Any], datetime.datetime - ) - DEFAULT_AC_TEMP = 22 PRESSURE_TO_BAR_FACTOR = 0.04 @@ -378,7 +373,7 @@ def notify_car_activity(self) -> None: self.last_car_activity = datetime.datetime.now(tz=datetime.UTC) self.__publish( topic=mqtt_topics.REFRESH_LAST_ACTIVITY, - value=datetime_to_str(self.last_car_activity), + value=self.last_car_activity, ) def notify_message(self, message: MessageEntity) -> None: @@ -505,8 +500,8 @@ def last_failed_refresh(self, value: datetime.datetime | None) -> None: ) ) self.__failed_refresh_counter = self.__failed_refresh_counter + 1 - self.publisher.publish_str( - self.get_topic(mqtt_topics.REFRESH_LAST_ERROR), datetime_to_str(value) + self.publisher.publish_datetime( + self.get_topic(mqtt_topics.REFRESH_LAST_ERROR), value ) self.publisher.publish_int( self.get_topic(mqtt_topics.REFRESH_PERIOD_ERROR), @@ -806,41 +801,19 @@ def update_battery_capacity(self, new_capacity: float) -> None: def is_remote_ac_running(self) -> bool: return self.__remote_ac_running - def __publish( + def __publish[V: Publishable]( self, *, topic: str, - value: Publishable | None, - validator: Callable[[Publishable], bool] = lambda _: True, + value: V | None, + validator: Callable[[V], bool] = lambda _: True, no_prefix: bool = False, - ) -> tuple[bool, Publishable | None]: + ) -> tuple[bool, V | None]: if value is None or not validator(value): return False, None actual_topic = topic if no_prefix else self.get_topic(topic) - published = self.__publish_directly(topic=actual_topic, value=value) - return published, value - - def __publish_directly(self, *, topic: str, value: Publishable) -> bool: - published = False - if isinstance(value, bool): - self.publisher.publish_bool(topic, value) - published = True - elif isinstance(value, int): - self.publisher.publish_int(topic, value) - published = True - elif isinstance(value, float): - self.publisher.publish_float(topic, value) - published = True - elif isinstance(value, str): - self.publisher.publish_str(topic, value) - published = True - elif isinstance(value, dict): - self.publisher.publish_json(topic, value) - published = True - elif isinstance(value, datetime.datetime): - self.publisher.publish_str(topic, datetime_to_str(value)) - published = True - return published + self.publisher.publish(actual_topic, value) + return True, value @property def vin(self) -> str: diff --git a/tests/mocks/__init__.py b/tests/mocks/__init__.py index de6db2d..b47d18b 100644 --- a/tests/mocks/__init__.py +++ b/tests/mocks/__init__.py @@ -7,6 +7,7 @@ if TYPE_CHECKING: from configuration import Configuration + from publisher.core import WirePayload LOG = logging.getLogger(__name__) @@ -14,11 +15,15 @@ class MessageCapturingConsolePublisher(ConsolePublisher): def __init__(self, configuration: Configuration) -> None: super().__init__(configuration) + # Test inspection map; consumers narrow per-key (e.g. json.loads on + # serialized dict topics), so keep the value type permissive here. self.map: dict[str, Any] = {} self.publish_count: dict[str, int] = {} @override - def internal_publish(self, key: str, value: Any, *, retain: bool = True) -> None: + def internal_publish( + self, key: str, value: WirePayload | None, *, retain: bool = True + ) -> None: self.map[key] = value self.publish_count[key] = self.publish_count.get(key, 0) + 1 LOG.debug(f"{key}: {value} (retain={retain})") diff --git a/tests/publisher/__init__.py b/tests/publisher/__init__.py new file mode 100644 index 0000000..9d48db4 --- /dev/null +++ b/tests/publisher/__init__.py @@ -0,0 +1 @@ +from __future__ import annotations diff --git a/tests/publisher/test_publish_dispatch.py b/tests/publisher/test_publish_dispatch.py new file mode 100644 index 0000000..7902771 --- /dev/null +++ b/tests/publisher/test_publish_dispatch.py @@ -0,0 +1,372 @@ +"""Conformance tests for `Publisher.publish` dispatch across all subclasses. + +`Publisher.publish` is a single non-abstract method on the ABC that dispatches +based on the runtime type of `value` to the corresponding typed +`publish_{bool,int,float,str,datetime,json}` method. `publish_datetime` is itself +a concrete ABC-level method that stringifies via :func:`utils.datetime_to_str` +and forwards to `publish_str`. The tests below exercise that dispatch directly +on every concrete `Publisher` subclass shipped by the project, plus a minimal +in-test subclass that locks the contract at the ABC level. + +The critical regression these tests guard against: `bool` is a subclass of +`int` in Python, so `isinstance(True, int)` is `True`. The dispatch must check +`bool` *before* `int` so that `publish(key, True)` reaches `publish_bool` (not +`publish_int`). +""" + +from __future__ import annotations + +from datetime import UTC, datetime +from typing import TYPE_CHECKING, Any, override +from unittest.mock import MagicMock, patch + +import pytest + +from configuration import Configuration, TransportProtocol +from publisher.core import Publishable, Publisher +from publisher.log_publisher import ConsolePublisher +from publisher.mqtt_publisher import MqttPublisher +from tests.mocks import MessageCapturingConsolePublisher +from utils import datetime_to_str + +if TYPE_CHECKING: + from collections.abc import Callable + + +KEY = "some/topic" + + +def _make_configuration() -> Configuration: + config = Configuration() + config.mqtt_topic = "saic" + config.saic_user = "user@example.com" + config.mqtt_transport_protocol = TransportProtocol.TCP + return config + + +# Each entry: (label, factory) where factory returns a fresh concrete Publisher. +PUBLISHER_FACTORIES: list[tuple[str, Callable[[], Publisher]]] = [ + ("MqttPublisher", lambda: MqttPublisher(_make_configuration())), + ("ConsolePublisher", lambda: ConsolePublisher(_make_configuration())), + ( + "MessageCapturingConsolePublisher", + lambda: MessageCapturingConsolePublisher(_make_configuration()), + ), +] + + +# (label, value, expected typed-method name) for arms where the value is +# forwarded to the typed method unchanged. +PASSTHROUGH_CASES: list[tuple[str, Publishable, str]] = [ + ("bool_true", True, "publish_bool"), + ("bool_false", False, "publish_bool"), + ("int_value", 5, "publish_int"), + ("int_zero", 0, "publish_int"), + ("float_value", 5.0, "publish_float"), + ("str_value", "hi", "publish_str"), + ( + "datetime_value", + datetime(2026, 5, 9, 12, 34, 56, tzinfo=UTC), + "publish_datetime", + ), +] + +TYPED_METHODS = ( + "publish_bool", + "publish_int", + "publish_float", + "publish_str", + "publish_datetime", + "publish_json", +) + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +@pytest.mark.parametrize( + ("case_label", "value", "expected_method"), + PASSTHROUGH_CASES, + ids=[label for label, _, _ in PASSTHROUGH_CASES], +) +def test_publish_dispatches_to_correct_typed_method( + publisher_label: str, + factory: Callable[[], Publisher], + case_label: str, + value: Publishable, + expected_method: str, +) -> None: + del publisher_label, case_label # only used as test ids + publisher = factory() + with ( + patch.object(publisher, "publish_bool") as m_bool, + patch.object(publisher, "publish_int") as m_int, + patch.object(publisher, "publish_float") as m_float, + patch.object(publisher, "publish_str") as m_str, + patch.object(publisher, "publish_datetime") as m_dt, + patch.object(publisher, "publish_json") as m_json, + ): + spies = { + "publish_bool": m_bool, + "publish_int": m_int, + "publish_float": m_float, + "publish_str": m_str, + "publish_datetime": m_dt, + "publish_json": m_json, + } + publisher.publish(KEY, value) + + spies[expected_method].assert_called_once_with(KEY, value, False, retain=True) + for name in TYPED_METHODS: + if name != expected_method: + spies[name].assert_not_called() + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +def test_publish_dict_routes_to_publish_json_with_retain( + publisher_label: str, + factory: Callable[[], Publisher], +) -> None: + """`dict` values dispatch to `publish_json`, forwarding `retain`.""" + del publisher_label + publisher = factory() + payload: dict[str, Any] = {"a": 1, "b": "two"} + with patch.object(publisher, "publish_json") as m_json: + publisher.publish(KEY, payload, retain=False) + m_json.assert_called_once_with(KEY, payload, False, retain=False) + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +@pytest.mark.parametrize( + ("case_label", "value", "expected_method"), + PASSTHROUGH_CASES, + ids=[label for label, _, _ in PASSTHROUGH_CASES], +) +def test_publish_forwards_retain_false_to_every_arm( + publisher_label: str, + factory: Callable[[], Publisher], + case_label: str, + value: Publishable, + expected_method: str, +) -> None: + """`retain=False` reaches every typed dispatch target, not just `publish_json`.""" + del publisher_label, case_label + publisher = factory() + with patch.object(publisher, expected_method) as m: + publisher.publish(KEY, value, retain=False) + m.assert_called_once_with(KEY, value, False, retain=False) + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +def test_publish_datetime_stringifies_via_publish_str( + publisher_label: str, + factory: Callable[[], Publisher], +) -> None: + """`publish_datetime` stringifies via `datetime_to_str` and forwards to `publish_str`.""" + del publisher_label + publisher = factory() + when = datetime(2026, 5, 9, 12, 34, 56, tzinfo=UTC) + with patch.object(publisher, "publish_str") as m_str: + publisher.publish_datetime(KEY, when) + m_str.assert_called_once_with(KEY, datetime_to_str(when), False, retain=True) + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +def test_publish_forwards_no_prefix_flag( + publisher_label: str, + factory: Callable[[], Publisher], +) -> None: + del publisher_label + publisher = factory() + with patch.object(publisher, "publish_str") as m_str: + publisher.publish(KEY, "hello", no_prefix=True) + m_str.assert_called_once_with(KEY, "hello", True, retain=True) + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +def test_publish_true_routes_to_bool_not_int( + publisher_label: str, + factory: Callable[[], Publisher], +) -> None: + """Locks in the bool-before-int dispatch ordering. + + `isinstance(True, int)` is `True` in Python, so a naive `int` check first + would silently route `True`/`False` to `publish_int`. + """ + del publisher_label + publisher = factory() + with ( + patch.object(publisher, "publish_bool") as m_bool, + patch.object(publisher, "publish_int") as m_int, + ): + publisher.publish(KEY, True) + m_bool.assert_called_once_with(KEY, True, False, retain=True) + m_int.assert_not_called() + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +def test_publish_int_does_not_route_to_bool( + publisher_label: str, + factory: Callable[[], Publisher], +) -> None: + del publisher_label + publisher = factory() + with ( + patch.object(publisher, "publish_bool") as m_bool, + patch.object(publisher, "publish_int") as m_int, + ): + publisher.publish(KEY, 5) + m_int.assert_called_once_with(KEY, 5, False, retain=True) + m_bool.assert_not_called() + + +@pytest.mark.parametrize( + ("publisher_label", "factory"), + PUBLISHER_FACTORIES, + ids=[label for label, _ in PUBLISHER_FACTORIES], +) +def test_publish_unsupported_type_raises( + publisher_label: str, + factory: Callable[[], Publisher], +) -> None: + """Unsupported runtime types raise rather than silently no-op.""" + del publisher_label + publisher = factory() + with pytest.raises(TypeError, match="Unsupported value type"): + publisher.publish(KEY, b"bytes-not-supported") # type: ignore[arg-type] + + +class _MinimalPublisher(Publisher): + """ABC-level publisher that mocks only the typed publish methods. + + Keeps the dispatch contract pinned even if all concrete subclasses were + to override `publish` in the future. + """ + + def __init__(self, config: Configuration) -> None: + super().__init__(config) + self.publish_bool = MagicMock() # type: ignore[method-assign] + self.publish_int = MagicMock() # type: ignore[method-assign] + self.publish_float = MagicMock() # type: ignore[method-assign] + self.publish_str = MagicMock() # type: ignore[method-assign] + self.publish_datetime = MagicMock() # type: ignore[method-assign] + self.publish_json = MagicMock() # type: ignore[method-assign] + self.clear_topic = MagicMock() # type: ignore[method-assign] + + @override + async def connect(self) -> None: + pass + + @override + def enable_commands(self) -> None: + pass + + @override + def is_connected(self) -> bool: + return True + + @override + def publish_json( + self, + key: str, + data: dict[str, Any], + no_prefix: bool = False, + *, + retain: bool = True, + ) -> None: + pass + + @override + def publish_str( + self, key: str, value: str, no_prefix: bool = False, *, retain: bool = True + ) -> None: + pass + + @override + def publish_int( + self, key: str, value: int, no_prefix: bool = False, *, retain: bool = True + ) -> None: + pass + + @override + def publish_bool( + self, key: str, value: bool, no_prefix: bool = False, *, retain: bool = True + ) -> None: + pass + + @override + def publish_float( + self, key: str, value: float, no_prefix: bool = False, *, retain: bool = True + ) -> None: + pass + + @override + def clear_topic(self, key: str, no_prefix: bool = False) -> None: + pass + + +@pytest.mark.parametrize( + ("case_label", "value", "expected_method"), + PASSTHROUGH_CASES, + ids=[label for label, _, _ in PASSTHROUGH_CASES], +) +def test_abc_level_publish_dispatch( + case_label: str, + value: Publishable, + expected_method: str, +) -> None: + del case_label + publisher = _MinimalPublisher(_make_configuration()) + publisher.publish(KEY, value) + spies: dict[str, MagicMock] = { + "publish_bool": publisher.publish_bool, # type: ignore[dict-item] + "publish_int": publisher.publish_int, # type: ignore[dict-item] + "publish_float": publisher.publish_float, # type: ignore[dict-item] + "publish_str": publisher.publish_str, # type: ignore[dict-item] + "publish_datetime": publisher.publish_datetime, # type: ignore[dict-item] + "publish_json": publisher.publish_json, # type: ignore[dict-item] + } + spies[expected_method].assert_called_once_with(KEY, value, False, retain=True) + for name in TYPED_METHODS: + if name != expected_method: + spies[name].assert_not_called() + + +def test_abc_level_publish_dict_with_retain() -> None: + publisher = _MinimalPublisher(_make_configuration()) + payload: dict[str, Any] = {"x": 1} + publisher.publish(KEY, payload, retain=False) + publisher.publish_json.assert_called_once_with(KEY, payload, False, retain=False) # type: ignore[attr-defined] + + +def test_abc_level_publish_datetime_routes_to_publish_datetime() -> None: + publisher = _MinimalPublisher(_make_configuration()) + when = datetime(2026, 5, 9, 12, 34, 56, tzinfo=UTC) + publisher.publish(KEY, when) + publisher.publish_datetime.assert_called_once_with(KEY, when, False, retain=True) # type: ignore[attr-defined] diff --git a/tests/status_publisher/test_message_publisher.py b/tests/status_publisher/test_message_publisher.py index 8cdc598..a042375 100644 --- a/tests/status_publisher/test_message_publisher.py +++ b/tests/status_publisher/test_message_publisher.py @@ -156,14 +156,20 @@ def test_event_payload_keys(self) -> None: class TestMessageEventResilience(unittest.TestCase): def test_event_publish_failure_does_not_break_processing(self) -> None: publisher, capturing = _make_publisher() - original_publish = publisher._publish_directly - - def failing_publish(**kwargs: Any) -> bool: - if mqtt_topics.EVENTS_VEHICLE_MESSAGE in kwargs["topic"]: + original_publish = capturing.publish + + def failing_publish( + key: str, + value: Any, + no_prefix: bool = False, + *, + retain: bool = True, + ) -> None: + if mqtt_topics.EVENTS_VEHICLE_MESSAGE in key: raise RuntimeError("MQTT down") - return original_publish(**kwargs) + original_publish(key, value, no_prefix, retain=retain) - with patch.object(publisher, "_publish_directly", side_effect=failing_publish): + with patch.object(capturing, "publish", side_effect=failing_publish): result = publisher.publish(_make_message()) assert result.processed is True