diff --git a/plugboard-schemas/plugboard_schemas/_validation.py b/plugboard-schemas/plugboard_schemas/_validation.py index 98eb886c..ee71c175 100644 --- a/plugboard-schemas/plugboard_schemas/_validation.py +++ b/plugboard-schemas/plugboard_schemas/_validation.py @@ -18,6 +18,9 @@ from ._validator_registry import validator +_SYSTEM_STOP_EVENT = "system_stop" + + def _build_component_graph( connectors: dict[str, dict[str, _t.Any]], ) -> dict[str, set[str]]: @@ -98,9 +101,11 @@ def validate_all_inputs_connected( for comp_name, comp_data in components.items(): io = comp_data.get("io", {}) all_inputs = set(io.get("inputs", [])) + input_events = set(io.get("input_events", [])) + has_non_system_input_events = bool(input_events - {_SYSTEM_STOP_EVENT}) connected = connected_inputs.get(comp_name, set()) unconnected = all_inputs - connected - if unconnected: + if unconnected and not has_non_system_input_events: errors.append(f"Component '{comp_name}' has unconnected inputs: {sorted(unconnected)}") return errors diff --git a/plugboard/component/component.py b/plugboard/component/component.py index 6fe0ad20..35a2a4c6 100644 --- a/plugboard/component/component.py +++ b/plugboard/component/component.py @@ -345,6 +345,11 @@ async def _wrapper() -> None: with self._job_id_ctx(): await self._set_status(Status.RUNNING, publish=not self._is_running) await self._io_read_with_status_check() + # IO can close here once all producers for the component's event-only inputs have + # finished emitting. Return before rebinding inputs so the last event-populated + # field values are not replayed as if they were fresh inputs in another step. + if self.io.is_closed: + return await self._handle_events() self._bind_inputs() if self._can_step: @@ -365,6 +370,11 @@ async def _wrapper() -> None: def _has_field_inputs(self) -> bool: return len(self.io.inputs) > 0 + @property + def _has_connected_field_inputs(self) -> bool: + """Whether any declared field inputs are connected via input channels.""" + return self.io.has_connected_field_inputs + @cached_property def _has_event_inputs(self) -> bool: input_events = set([evt.safe_type() for evt in self.io.input_events]) @@ -409,7 +419,7 @@ async def _io_read_with_status_check(self) -> None: task.cancel() for task in done: exc = task.exception() - if isinstance(exc, EventStreamClosedError) and len(self.io.inputs) == 0: + if isinstance(exc, EventStreamClosedError) and not self._has_connected_field_inputs: await self.io.close() # Call close for final wait and flush event buffer elif exc is not None: raise exc @@ -422,7 +432,7 @@ async def _periodic_status_check(self) -> None: # TODO : Eventually producer graph update will be event driven. For now, # : the update is performed periodically, so it's called here along # : with the status check. - if len(self.io.inputs) == 0: + if not self._has_connected_field_inputs: await self._update_producer_graph() async def _status_check(self) -> None: diff --git a/plugboard/component/io_controller.py b/plugboard/component/io_controller.py index 7500aee2..5ac67f7c 100644 --- a/plugboard/component/io_controller.py +++ b/plugboard/component/io_controller.py @@ -86,8 +86,9 @@ def is_closed(self) -> bool: """Returns `True` if the `IOController` is closed, `False` otherwise.""" return self._is_closed - @cached_property - def _has_field_inputs(self) -> bool: + @property + def has_connected_field_inputs(self) -> bool: + """Returns whether any field inputs are connected via channels.""" return len(self._input_channels) > 0 @cached_property @@ -96,7 +97,7 @@ def _has_event_inputs(self) -> bool: @cached_property def _has_inputs(self) -> bool: - return self._has_field_inputs or self._has_event_inputs + return self.has_connected_field_inputs or self._has_event_inputs async def read(self, timeout: float | None = None) -> None: """Reads data and/or events from input channels. @@ -139,7 +140,7 @@ async def read(self, timeout: float | None = None) -> None: def _set_read_tasks(self) -> list[asyncio.Task]: read_tasks: list[asyncio.Task] = [] - if self._has_field_inputs: + if self.has_connected_field_inputs: if _fields_read_task not in self._read_tasks: read_fields_task = asyncio.create_task(self._read_fields(), name=_fields_read_task) self._read_tasks[_fields_read_task] = read_fields_task @@ -374,7 +375,7 @@ def _add_channel_for_event( def _create_input_field_group_tasks(self) -> None: """Groups input field channels by field name and launches read tasks for group inputs.""" - if not self._has_field_inputs: + if not self.has_connected_field_inputs: return field_channels: dict[str, list[tuple[_t_field_key, Channel]]] = defaultdict(list) for key, chan in self._input_channels.items(): diff --git a/tests/integration/test_process_with_components_run.py b/tests/integration/test_process_with_components_run.py index fe047ae8..8f48a2dc 100644 --- a/tests/integration/test_process_with_components_run.py +++ b/tests/integration/test_process_with_components_run.py @@ -23,6 +23,7 @@ ) from plugboard.events import Event from plugboard.exceptions import ConstraintError, NotInitialisedError, ProcessStatusError +from plugboard.library import FileWriter from plugboard.process import LocalProcess, Process, RayProcess from plugboard.schemas import ConnectorSpec, Status from tests.conftest import ComponentTestHelper, zmq_connector_cls @@ -459,6 +460,85 @@ async def test_event_driven_process_shutdown( await process.destroy() +class MessageEventData(BaseModel): + """Data for a message event.""" + + message: str + + +class MessageEvent(Event): + """Event carrying a file-writer message.""" + + type: _t.ClassVar[str] = "message_event" + data: MessageEventData + + +class MessageEventGenerator(ComponentTestHelper): + """Produces a fixed number of message events.""" + + io = IO(output_events=[MessageEvent]) + + def __init__(self, iters: int, *args: _t.Any, **kwargs: _t.Any) -> None: + super().__init__(*args, **kwargs) + self._iters = iters + + async def init(self) -> None: + await super().init() + self._seq = iter(range(self._iters)) + + async def step(self) -> None: + try: + idx = next(self._seq) + except StopIteration: + await self.io.close() + else: + evt = MessageEvent( + source=self.name, + data=MessageEventData(message=f"Message {idx}"), + ) + self.io.queue_event(evt) + await super().step() + + +class EventReaderFileWriter(FileWriter): + """`FileWriter` variant that adds event handling instead of a connector for `message`.""" + + io = IO(input_events=[MessageEvent]) + + @MessageEvent.handler + async def handle_message(self, event: MessageEvent) -> None: + self.message = event.data.message + + +@pytest.mark.asyncio +async def test_event_driven_file_writer_reuse(tmp_path: Path) -> None: + """Test that field-input components can be reused in event-driven processes.""" + output_path = tmp_path / "output_messages.csv" + components = [ + MessageEventGenerator(iters=3, name="message_event_generator"), + EventReaderFileWriter( + path=output_path, + name="event_reader_file_writer", + field_names=["message"], + ), + ] + event_connectors = AsyncioConnector.builder().build_event_connectors(components) + process = LocalProcess(components=components, connectors=event_connectors) + + await process.init() + await process.run() + + assert process.status == Status.COMPLETED + assert output_path.read_text().splitlines() == [ + "message", + "Message 0", + "Message 1", + "Message 2", + ] + + await process.destroy() + + _SHORT_TIMEOUT = 0.1 diff --git a/tests/unit/test_process_validation.py b/tests/unit/test_process_validation.py index 02e0a4d2..b0ec7482 100644 --- a/tests/unit/test_process_validation.py +++ b/tests/unit/test_process_validation.py @@ -303,6 +303,21 @@ def test_no_inputs_no_errors(self) -> None: errors = validate_all_inputs_connected(pd) assert errors == [] + def test_missing_inputs_allowed_for_event_driven_component_reuse(self) -> None: + """Unconnected inputs are allowed when non-system input events can populate them.""" + pd = _make_process_dict( + components={ + "producer": _make_component("producer", output_events=["message_event"]), + "writer": _make_component( + "writer", + inputs=["message"], + input_events=["system_stop", "message_event"], + ), + }, + ) + errors = validate_all_inputs_connected(pd) + assert errors == [] + # --------------------------------------------------------------------------- # Tests for validate_input_events