diff --git a/examples/ezmsg_toy.py b/examples/ezmsg_toy.py index 08f51f5..69f7141 100644 --- a/examples/ezmsg_toy.py +++ b/examples/ezmsg_toy.py @@ -24,20 +24,33 @@ class LFOSettings(ez.Settings): update_rate: float = 2.0 # Hz, update rate +class LFOState(ez.State): + start_time: float + cur_settings: LFOSettings + + class LFO(ez.Unit): SETTINGS = LFOSettings + STATE = LFOState OUTPUT = ez.OutputStream(float) + INPUT_SETTINGS = ez.InputStream(LFOSettings) + async def initialize(self) -> None: - self.start_time = time.time() + self.STATE.cur_settings = self.SETTINGS + self.STATE.start_time = time.time() + @ez.subscriber(INPUT_SETTINGS) + async def on_settings(self, msg: LFOSettings) -> None: + self.STATE.cur_settings = msg + @ez.publisher(OUTPUT) async def generate(self) -> AsyncGenerator: while True: - t = time.time() - self.start_time - yield self.OUTPUT, math.sin(2.0 * math.pi * self.SETTINGS.freq * t) - await asyncio.sleep(1.0 / self.SETTINGS.update_rate) + t = time.time() - self.STATE.start_time + yield self.OUTPUT, math.sin(2.0 * math.pi * self.STATE.cur_settings.freq * t) + await asyncio.sleep(1.0 / self.STATE.cur_settings.update_rate) # MESSAGE GENERATOR @@ -45,17 +58,30 @@ class MessageGeneratorSettings(ez.Settings): message: str +class MessageGeneratorState(ez.State): + cur_settings: MessageGeneratorSettings + + class MessageGenerator(ez.Unit): SETTINGS = MessageGeneratorSettings + STATE = MessageGeneratorState OUTPUT = ez.OutputStream(str) + INPUT_SETTINGS = ez.InputStream(MessageGeneratorSettings) + + async def initialize(self) -> None: + self.STATE.cur_settings = self.SETTINGS + + @ez.subscriber(INPUT_SETTINGS) + async def on_settings(self, msg: MessageGeneratorSettings) -> None: + self.STATE.cur_settings = msg @ez.publisher(OUTPUT) async def spawn_message(self) -> AsyncGenerator: while True: await asyncio.sleep(1.0) - ez.logger.info(f"Spawning {self.SETTINGS.message}") - yield self.OUTPUT, self.SETTINGS.message + ez.logger.info(f"Spawning 
{self.STATE.cur_settings.message}") + yield self.OUTPUT, self.STATE.cur_settings.message @ez.publisher(OUTPUT) async def spawn_once(self) -> AsyncGenerator: diff --git a/examples/profiling_tui.py b/examples/profiling_tui.py new file mode 100644 index 0000000..3f49556 --- /dev/null +++ b/examples/profiling_tui.py @@ -0,0 +1,450 @@ +#!/usr/bin/env python3 +""" +Simple live profiling TUI for ezmsg GraphServer. + +Features: +- Periodic profiling snapshot view broken out by publisher/subscriber endpoints +- Live trace sample counts via GraphContext.subscribe_profiling_trace() +- Optional automatic trace enablement for discovered processes + +Usage: + .venv/bin/python examples/profiling_tui.py --host 127.0.0.1 --port 25978 +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import time +from dataclasses import dataclass +from uuid import UUID + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ProcessProfilingSnapshot, + ProfilingStreamControl, + ProfilingTraceControl, +) +from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT + + +def _truncate(text: object, width: int) -> str: + text = str(text) + if width <= 3: + return text[:width] + if len(text) <= width: + return text + return text[: width - 3] + "..." 
+ + +def _fmt_float(value: float, digits: int = 2) -> str: + return f"{value:.{digits}f}" + + +@dataclass +class PublisherView: + process_id: UUID + topic: str + endpoint_id: str + published_total: int + published_window: int + publish_rate_hz: float + publish_delta_ms_avg: float + inflight_current: int + inflight_peak: int + trace_samples_seen: int + trace_last_age_s: float | None + backpressure_wait_ms_window: float + + +@dataclass +class SubscriberView: + process_id: UUID + topic: str + endpoint_id: str + channel_kind: str + received_total: int + received_window: int + lease_time_ms_avg: float + user_span_ms_avg: float + attributable_backpressure_ms_window: float + attributable_backpressure_events_total: int + trace_samples_seen: int + trace_last_age_s: float | None + + +class ProfilingTUI: + def __init__( + self, + ctx: GraphContext, + *, + snapshot_interval: float, + trace_interval: float, + trace_max_samples: int, + auto_trace: bool, + trace_sample_mod: int, + max_rows: int, + ) -> None: + self.ctx = ctx + self.snapshot_interval = max(0.2, snapshot_interval) + self.trace_interval = max(0.01, trace_interval) + self.trace_max_samples = max(1, trace_max_samples) + self.auto_trace = auto_trace + self.trace_sample_mod = max(1, trace_sample_mod) + self.max_rows = max(5, max_rows) + + self.snapshots: dict[UUID, ProcessProfilingSnapshot] = {} + self.route_units: dict[UUID, str] = {} + self.trace_enabled_processes: set[UUID] = set() + self.trace_errors: dict[UUID, str] = {} + self.trace_samples_seen_by_endpoint: dict[str, int] = {} + self.trace_last_timestamp_by_endpoint: dict[str, float] = {} + self.last_snapshot_time: float | None = None + + self._snapshot_task: asyncio.Task[None] | None = None + self._trace_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + await self._refresh_snapshot() + self._snapshot_task = asyncio.create_task(self._snapshot_loop()) + self._trace_task = asyncio.create_task(self._trace_loop()) + + async def close(self) 
-> None: + for task in (self._snapshot_task, self._trace_task): + if task is not None: + task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await task + + if not self.auto_trace: + return + + for process_id, route_unit in self.route_units.items(): + if process_id not in self.trace_enabled_processes: + continue + with contextlib.suppress(Exception): + await self.ctx.process_set_profiling_trace( + route_unit, + ProfilingTraceControl(enabled=False), + timeout=0.5, + ) + + async def _snapshot_loop(self) -> None: + while True: + await self._refresh_snapshot() + await asyncio.sleep(self.snapshot_interval) + + async def _refresh_snapshot(self) -> None: + graph_snapshot = await self.ctx.snapshot() + route_units: dict[UUID, str] = {} + for process in graph_snapshot.processes.values(): + if process.units: + route_units[process.process_id] = process.units[0] + self.route_units = route_units + + if self.auto_trace: + for process_id, route_unit in route_units.items(): + if process_id in self.trace_enabled_processes: + continue + try: + response = await self.ctx.process_set_profiling_trace( + route_unit, + ProfilingTraceControl( + enabled=True, + sample_mod=self.trace_sample_mod, + ), + timeout=0.5, + ) + if response.ok: + self.trace_enabled_processes.add(process_id) + self.trace_errors.pop(process_id, None) + else: + self.trace_errors[process_id] = str( + response.error or "unknown error" + ) + except Exception as exc: + self.trace_errors[process_id] = str(exc) + + self.snapshots = await self.ctx.profiling_snapshot_all( + timeout_per_process=max(0.1, self.snapshot_interval * 0.8) + ) + self.last_snapshot_time = time.time() + + async def _trace_loop(self) -> None: + async for batch in self.ctx.subscribe_profiling_trace( + ProfilingStreamControl( + interval=self.trace_interval, + max_samples=self.trace_max_samples, + ) + ): + for process_batch in batch.batches.values(): + for sample in process_batch.samples: + endpoint_id = sample.endpoint_id + 
self.trace_samples_seen_by_endpoint[endpoint_id] = ( + self.trace_samples_seen_by_endpoint.get(endpoint_id, 0) + 1 + ) + self.trace_last_timestamp_by_endpoint[endpoint_id] = batch.timestamp + + def _trace_for_endpoint(self, endpoint_id: str) -> tuple[int, float | None]: + now = time.time() + count = self.trace_samples_seen_by_endpoint.get(endpoint_id, 0) + ts = self.trace_last_timestamp_by_endpoint.get(endpoint_id) + age = None if ts is None else max(0.0, now - ts) + return count, age + + def _publisher_rows(self) -> list[PublisherView]: + rows: list[PublisherView] = [] + for process_id, snapshot in self.snapshots.items(): + for pub in snapshot.publishers.values(): + trace_count, trace_age = self._trace_for_endpoint(pub.endpoint_id) + rows.append( + PublisherView( + process_id=process_id, + topic=pub.topic, + endpoint_id=pub.endpoint_id, + published_total=pub.messages_published_total, + published_window=pub.messages_published_window, + publish_rate_hz=pub.publish_rate_hz_window, + publish_delta_ms_avg=pub.publish_delta_ns_avg_window / 1_000_000.0, + inflight_current=pub.inflight_messages_current, + inflight_peak=pub.inflight_messages_peak_window, + trace_samples_seen=trace_count, + trace_last_age_s=trace_age, + backpressure_wait_ms_window=( + pub.backpressure_wait_ns_window / 1_000_000.0 + ), + ) + ) + rows.sort( + key=lambda row: ( + -row.publish_rate_hz, + -row.published_total, + row.process_id, + row.topic, + ) + ) + return rows + + def _subscriber_rows(self) -> list[SubscriberView]: + rows: list[SubscriberView] = [] + for process_id, snapshot in self.snapshots.items(): + for sub in snapshot.subscribers.values(): + trace_count, trace_age = self._trace_for_endpoint(sub.endpoint_id) + channel_kind = ( + sub.channel_kind_last.value + if hasattr(sub.channel_kind_last, "value") + else str(sub.channel_kind_last) + ) + rows.append( + SubscriberView( + process_id=process_id, + topic=sub.topic, + endpoint_id=sub.endpoint_id, + channel_kind=channel_kind, + 
received_total=sub.messages_received_total, + received_window=sub.messages_received_window, + lease_time_ms_avg=sub.lease_time_ns_avg_window / 1_000_000.0, + user_span_ms_avg=sub.user_span_ns_avg_window / 1_000_000.0, + attributable_backpressure_ms_window=( + sub.attributable_backpressure_ns_window / 1_000_000.0 + ), + attributable_backpressure_events_total=( + sub.attributable_backpressure_events_total + ), + trace_samples_seen=trace_count, + trace_last_age_s=trace_age, + ) + ) + rows.sort( + key=lambda row: ( + -row.lease_time_ms_avg, + -row.received_total, + row.process_id, + row.topic, + ) + ) + return rows + + def render(self) -> None: + print("\x1bc", end="") + print("ezmsg profiling tui") + print("Ctrl-C to quit") + print( + "snapshot interval=" + f"{self.snapshot_interval:.2f}s, trace interval={self.trace_interval:.2f}s, " + f"trace max_samples={self.trace_max_samples}, auto_trace={self.auto_trace}" + ) + if self.last_snapshot_time is not None: + print( + "last snapshot age: " + f"{_fmt_float(max(0.0, time.time() - self.last_snapshot_time), 2)}s" + ) + print( + f"processes discovered={len(self.route_units)} " + f"publishers={sum(len(s.publishers) for s in self.snapshots.values())} " + f"subscribers={sum(len(s.subscribers) for s in self.snapshots.values())}" + ) + + publisher_rows = self._publisher_rows() + subscriber_rows = self._subscriber_rows() + + print("\nPublishers") + pub_header = ( + f"{'Process':<20} {'Topic':<26} {'Endpoint':<24} " + f"{'Total':>8} {'Win':>6} {'RateHz':>8} {'DeltaMs':>8} " + f"{'InFl':>5} {'InPk':>5} {'BPmsW':>8} {'Trace':>7} {'TAge':>6}" + ) + print(pub_header) + print("-" * len(pub_header)) + if not publisher_rows: + print("") + else: + for row in publisher_rows[: self.max_rows]: + trace_age = ( + "-" if row.trace_last_age_s is None else _fmt_float(row.trace_last_age_s, 2) + ) + print( + f"{_truncate(row.process_id, 20):<20} " + f"{_truncate(row.topic, 26):<26} " + f"{_truncate(row.endpoint_id, 24):<24} " + 
f"{row.published_total:>8} " + f"{row.published_window:>6} " + f"{_fmt_float(row.publish_rate_hz, 2):>8} " + f"{_fmt_float(row.publish_delta_ms_avg, 2):>8} " + f"{row.inflight_current:>5} " + f"{row.inflight_peak:>5} " + f"{_fmt_float(row.backpressure_wait_ms_window, 2):>8} " + f"{row.trace_samples_seen:>7} " + f"{trace_age:>6}" + ) + + print("\nSubscribers") + sub_header = ( + f"{'Process':<20} {'Topic':<26} {'Endpoint':<24} {'Kind':<6} " + f"{'Total':>8} {'Win':>6} {'LeaseMs':>8} {'UserMs':>8} " + f"{'BPmsW':>8} {'BPev':>6} {'Trace':>7} {'TAge':>6}" + ) + print(sub_header) + print("-" * len(sub_header)) + if not subscriber_rows: + print("") + else: + for row in subscriber_rows[: self.max_rows]: + trace_age = ( + "-" if row.trace_last_age_s is None else _fmt_float(row.trace_last_age_s, 2) + ) + print( + f"{_truncate(row.process_id, 20):<20} " + f"{_truncate(row.topic, 26):<26} " + f"{_truncate(row.endpoint_id, 24):<24} " + f"{_truncate(row.channel_kind, 6):<6} " + f"{row.received_total:>8} " + f"{row.received_window:>6} " + f"{_fmt_float(row.lease_time_ms_avg, 2):>8} " + f"{_fmt_float(row.user_span_ms_avg, 2):>8} " + f"{_fmt_float(row.attributable_backpressure_ms_window, 2):>8} " + f"{row.attributable_backpressure_events_total:>6} " + f"{row.trace_samples_seen:>7} " + f"{trace_age:>6}" + ) + + if self.trace_errors: + print("\ntrace errors:") + for process_id, err in sorted(self.trace_errors.items(), key=lambda item: str(item[0])): + print(f" {_truncate(str(process_id), 30)}: {_truncate(err, 120)}") + + +def _parse_address(host: str, port: int) -> tuple[str, int]: + return (host, port) + + +async def _run_tui(args: argparse.Namespace) -> None: + async with GraphContext( + _parse_address(args.host, args.port), auto_start=args.auto_start + ) as ctx: + tui = ProfilingTUI( + ctx, + snapshot_interval=args.snapshot_interval, + trace_interval=args.trace_interval, + trace_max_samples=args.max_samples, + auto_trace=args.auto_trace, + trace_sample_mod=args.sample_mod, + 
max_rows=args.max_rows, + ) + await tui.start() + try: + while True: + tui.render() + await asyncio.sleep(max(0.1, args.render_interval)) + finally: + await tui.close() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="ezmsg profiling TUI") + parser.add_argument("--host", default=DEFAULT_HOST, help="GraphServer host") + parser.add_argument( + "--port", + type=int, + default=GRAPHSERVER_PORT_DEFAULT, + help="GraphServer port", + ) + parser.add_argument( + "--auto-start", + action="store_true", + help="Allow GraphContext to auto-start GraphServer if unavailable", + ) + parser.add_argument( + "--snapshot-interval", + type=float, + default=1.0, + help="Seconds between snapshot refreshes", + ) + parser.add_argument( + "--trace-interval", + type=float, + default=0.02, + help="Seconds between GraphServer trace stream batches", + ) + parser.add_argument( + "--max-samples", + type=int, + default=5000, + help="Max samples per process per streamed batch", + ) + parser.add_argument( + "--sample-mod", + type=int, + default=1, + help="Trace sampling divisor when auto-enabling trace", + ) + parser.add_argument( + "--render-interval", + type=float, + default=0.5, + help="Seconds between TUI redraws", + ) + parser.add_argument( + "--max-rows", + type=int, + default=30, + help="Max publisher/subscriber rows to render per table", + ) + parser.add_argument( + "--no-auto-trace", + action="store_true", + help="Do not auto-enable trace mode on discovered processes", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + args.auto_trace = not args.no_auto_trace + asyncio.run(_run_tui(args)) + + +if __name__ == "__main__": + main() diff --git a/examples/settings_tui.py b/examples/settings_tui.py new file mode 100644 index 0000000..83acf62 --- /dev/null +++ b/examples/settings_tui.py @@ -0,0 +1,389 @@ +#!/usr/bin/env python3 +""" +Simple settings TUI for ezmsg GraphServer. 
+ +Features: +- Live settings view (push updates via GraphContext.subscribe_settings_events) +- Inspect component metadata and current settings snapshot +- Publish patched settings to components with dynamic INPUT_SETTINGS + +Usage: + .venv/bin/python examples/settings_tui.py --host 127.0.0.1 --port 25978 + +Commands: + help + refresh + inspect + set {"field": 123, "nested": {"gain": 0.5}} + quit + +Notes: +- Updates are sent over normal pub/sub to the component's INPUT_SETTINGS topic. +- For safe updates, the script expects pickled current settings to be available + and unpickleable in this environment. +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import json +import pickle +from dataclasses import dataclass, is_dataclass, replace +from typing import Any + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ComponentMetadataType, + GraphMetadata, + SettingsChangedEvent, + SettingsSnapshotValue, +) +from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT +from ezmsg.core.pubclient import Publisher + + +def _truncate(text: str, width: int) -> str: + if width <= 3: + return text[:width] + if len(text) <= width: + return text + return text[: width - 3] + "..." 
+ + +def _format_settings(value: SettingsSnapshotValue | None, width: int = 72) -> str: + if value is None: + return "-" + return _truncate(repr(value.repr_value), width) + + +def _deep_merge_dict(base: dict[str, Any], patch: dict[str, Any]) -> dict[str, Any]: + merged = dict(base) + for key, patch_value in patch.items(): + base_value = merged.get(key) + if isinstance(base_value, dict) and isinstance(patch_value, dict): + merged[key] = _deep_merge_dict(base_value, patch_value) + else: + merged[key] = patch_value + return merged + + +def _patch_dataclass(obj: Any, patch: dict[str, Any]) -> Any: + updates: dict[str, Any] = {} + for key, patch_value in patch.items(): + if not hasattr(obj, key): + raise KeyError(f"Settings object has no field '{key}'") + current = getattr(obj, key) + if is_dataclass(current) and isinstance(patch_value, dict): + updates[key] = _patch_dataclass(current, patch_value) + elif isinstance(current, dict) and isinstance(patch_value, dict): + updates[key] = _deep_merge_dict(current, patch_value) + else: + updates[key] = patch_value + return replace(obj, **updates) + + +def _patch_value(value: Any, patch: dict[str, Any]) -> Any: + if is_dataclass(value): + return _patch_dataclass(value, patch) + if isinstance(value, dict): + return _deep_merge_dict(value, patch) + raise TypeError( + f"Cannot patch settings value of type {type(value).__name__}. " + "Only dataclass/dict settings are supported by this script." 
+ ) + + +def _components_from_metadata( + metadata: GraphMetadata | None, +) -> dict[str, ComponentMetadataType]: + if metadata is None: + return {} + return dict(metadata.components) + + +@dataclass +class ComponentRow: + address: str + name: str + component_type: str + settings_type: str + dynamic_enabled: bool + input_topic: str | None + + +class SettingsTUI: + def __init__(self, ctx: GraphContext): + self.ctx = ctx + self.settings: dict[str, SettingsSnapshotValue] = {} + self.components: dict[str, ComponentRow] = {} + self.row_addresses: list[str] = [] + self.last_seq = 0 + self.publishers: dict[str, Publisher] = {} + self._event_queue: asyncio.Queue[SettingsChangedEvent] = asyncio.Queue() + self._watch_task: asyncio.Task[None] | None = None + + async def initialize(self) -> None: + await self.refresh() + events = await self.ctx.settings_events(after_seq=0) + for event in events: + self.settings[event.component_address] = event.value + self.last_seq = max(self.last_seq, event.seq) + self._watch_task = asyncio.create_task(self._watch_settings_events()) + + async def close(self) -> None: + if self._watch_task is not None: + self._watch_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._watch_task + + async def _watch_settings_events(self) -> None: + async for event in self.ctx.subscribe_settings_events(after_seq=self.last_seq): + await self._event_queue.put(event) + + async def refresh(self) -> None: + snapshot = await self.ctx.snapshot() + settings = await self.ctx.settings_snapshot() + + components: dict[str, ComponentRow] = {} + for session in snapshot.sessions.values(): + for address, comp in _components_from_metadata(session.metadata).items(): + components[address] = ComponentRow( + address=address, + name=comp.name, + component_type=comp.component_type, + settings_type=comp.settings_type, + dynamic_enabled=comp.dynamic_settings.enabled, + input_topic=comp.dynamic_settings.input_topic, + ) + + self.components = components + 
self.settings = settings + + async def drain_events(self) -> int: + count = 0 + while True: + try: + event = self._event_queue.get_nowait() + except asyncio.QueueEmpty: + break + self.settings[event.component_address] = event.value + self.last_seq = max(self.last_seq, event.seq) + count += 1 + return count + + def render(self, pending_updates: int = 0) -> None: + print("\x1bc", end="") + print("ezmsg settings tui") + print( + "Commands: help, refresh, inspect , " + "set , quit" + ) + if pending_updates > 0: + print(f"Applied {pending_updates} new settings event(s).") + + all_addresses = sorted(set(self.settings) | set(self.components)) + self.row_addresses = all_addresses + + header = ( + f"{'Row':<4} {'Component':<36} {'Dyn':<4} " + f"{'INPUT_SETTINGS Topic':<42} {'Current Settings':<72}" + ) + print() + print(header) + print("-" * len(header)) + + for idx, address in enumerate(all_addresses, start=1): + comp = self.components.get(address) + settings = self.settings.get(address) + + dynamic = "yes" if comp is not None and comp.dynamic_enabled else "no" + input_topic = ( + comp.input_topic if comp is not None and comp.input_topic is not None else "-" + ) + print( + f"{idx:<4} " + f"{_truncate(address, 36):<36} " + f"{dynamic:<4} " + f"{_truncate(input_topic, 42):<42} " + f"{_format_settings(settings):<72}" + ) + + def resolve_target(self, token: str) -> str: + if token.isdigit(): + idx = int(token) - 1 + if idx < 0 or idx >= len(self.row_addresses): + raise ValueError(f"Row index out of range: {token}") + return self.row_addresses[idx] + return token + + async def inspect(self, token: str) -> None: + address = self.resolve_target(token) + comp = self.components.get(address) + settings = self.settings.get(address) + print("\n--- inspect ---") + print(f"address: {address}") + if comp is None: + print("metadata: ") + else: + print(f"name: {comp.name}") + print(f"component_type: {comp.component_type}") + print(f"settings_type: {comp.settings_type}") + 
print(f"dynamic_settings.enabled: {comp.dynamic_enabled}") + print(f"dynamic_settings.input_topic: {comp.input_topic}") + if settings is None: + print("current_settings: ") + else: + print(f"repr: {settings.repr_value!r}") + print(f"has_pickled_payload: {settings.serialized is not None}") + if settings.serialized is not None: + try: + obj = pickle.loads(settings.serialized) + print(f"unpickled_type: {type(obj).__module__}.{type(obj).__name__}") + except Exception as exc: + print(f"unpickled_type: ") + + async def set_settings(self, token: str, patch: dict[str, Any]) -> str: + address = self.resolve_target(token) + comp = self.components.get(address) + if comp is None: + raise ValueError(f"No component metadata available for '{address}'") + if not comp.dynamic_enabled or comp.input_topic is None: + raise ValueError( + f"Component '{address}' is not dynamic-settings enabled or has no INPUT_SETTINGS topic" + ) + + current = self.settings.get(address) + if current is None: + raise ValueError(f"No current settings snapshot for '{address}'") + if current.serialized is None: + raise ValueError( + f"No serialized settings for '{address}'. Cannot safely build updated object." 
+ ) + + try: + current_obj = pickle.loads(current.serialized) + except Exception as exc: + raise ValueError( + f"Could not unpickle current settings for '{address}': {exc}" + ) from exc + + updated_obj = _patch_value(current_obj, patch) + publisher = self.publishers.get(comp.input_topic) + if publisher is None: + publisher = await self.ctx.publisher(comp.input_topic) + self.publishers[comp.input_topic] = publisher + + await publisher.broadcast(updated_obj) + return f"Published settings update to {comp.input_topic}" + + +def _parse_patch(json_text: str) -> dict[str, Any]: + try: + patch = json.loads(json_text) + except json.JSONDecodeError as exc: + raise ValueError(f"Invalid JSON patch: {exc}") from exc + if not isinstance(patch, dict): + raise ValueError("Patch must be a JSON object") + return patch + + +def _parse_address(host: str, port: int) -> tuple[str, int]: + return (host, port) + + +async def _run_tui(host: str, port: int, auto_start: bool) -> None: + address = _parse_address(host, port) + + async with GraphContext(address, auto_start=auto_start) as ctx: + tui = SettingsTUI(ctx) + await tui.initialize() + try: + while True: + pending = await tui.drain_events() + tui.render(pending_updates=pending) + cmdline = (await asyncio.to_thread(input, "\nsettings-tui> ")).strip() + if not cmdline: + continue + + cmd, *rest = cmdline.split(" ", 1) + if cmd in {"q", "quit", "exit"}: + break + + if cmd in {"h", "help"}: + print( + "\nhelp:\n" + " refresh\n" + " inspect \n" + " set \n" + " quit\n" + ) + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + if cmd == "refresh": + await tui.refresh() + continue + + if cmd == "inspect": + if not rest: + print("Usage: inspect ") + else: + await tui.inspect(rest[0].strip()) + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + if cmd == "set": + if not rest: + print("Usage: set ") + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + target_and_patch = 
rest[0].strip() + if " " not in target_and_patch: + print("Usage: set ") + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + target, patch_text = target_and_patch.split(" ", 1) + try: + patch = _parse_patch(patch_text.strip()) + result = await tui.set_settings(target.strip(), patch) + print(result) + except Exception as exc: + print(f"set failed: {exc}") + await asyncio.to_thread(input, "Press Enter to continue...") + continue + + print(f"Unknown command: {cmd}") + await asyncio.to_thread(input, "Press Enter to continue...") + finally: + await tui.close() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="ezmsg settings TUI") + parser.add_argument("--host", default=DEFAULT_HOST, help="GraphServer host") + parser.add_argument( + "--port", + type=int, + default=GRAPHSERVER_PORT_DEFAULT, + help="GraphServer port", + ) + parser.add_argument( + "--auto-start", + action="store_true", + help="Allow GraphContext to auto-start GraphServer if unavailable", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + asyncio.run(_run_tui(args.host, args.port, args.auto_start)) + + +if __name__ == "__main__": + main() diff --git a/examples/topology_tui.py b/examples/topology_tui.py new file mode 100644 index 0000000..56876a3 --- /dev/null +++ b/examples/topology_tui.py @@ -0,0 +1,285 @@ +#!/usr/bin/env python3 +""" +Simple live topology TUI for ezmsg GraphServer. 
+ +Features: +- Push-based topology event subscription +- Live graph summary (nodes/edges/sessions/processes) +- Process ownership view +- Current edge list +- Recent topology event log + +Usage: + PYTHONPATH=src .venv/bin/python examples/topology_tui.py --host 127.0.0.1 --port 25978 +""" + +from __future__ import annotations + +import argparse +import asyncio +import contextlib +import time +from collections import deque + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import GraphSnapshot, TopologyChangedEvent +from ezmsg.core.netprotocol import DEFAULT_HOST, GRAPHSERVER_PORT_DEFAULT + + +def _truncate(text: object, width: int) -> str: + text = str(text) + if width <= 3: + return text[:width] + if len(text) <= width: + return text + return text[: width - 3] + "..." + + +def _fmt_age(age_s: float) -> str: + return f"{age_s:0.2f}s" + + +def _flatten_edges(snapshot: GraphSnapshot) -> list[tuple[str, str]]: + edges: list[tuple[str, str]] = [] + for src, destinations in snapshot.graph.items(): + for dst in destinations: + edges.append((src, dst)) + edges.sort(key=lambda edge: (edge[0], edge[1])) + return edges + + +class TopologyTUI: + def __init__( + self, + ctx: GraphContext, + *, + snapshot_interval: float, + render_interval: float, + max_edges: int, + max_events: int, + max_processes: int, + ) -> None: + self.ctx = ctx + self.snapshot_interval = max(0.2, snapshot_interval) + self.render_interval = max(0.1, render_interval) + self.max_edges = max(10, max_edges) + self.max_events = max(10, max_events) + self.max_processes = max(5, max_processes) + + self.snapshot: GraphSnapshot | None = None + self.last_snapshot_time: float | None = None + self._events: deque[TopologyChangedEvent] = deque(maxlen=self.max_events) + self._event_queue: asyncio.Queue[TopologyChangedEvent] = asyncio.Queue() + + self._watch_task: asyncio.Task[None] | None = None + + async def start(self) -> None: + await self._refresh_snapshot() + self._watch_task = 
asyncio.create_task(self._watch_topology_events()) + + async def close(self) -> None: + if self._watch_task is not None: + self._watch_task.cancel() + with contextlib.suppress(asyncio.CancelledError): + await self._watch_task + + async def _watch_topology_events(self) -> None: + after_seq = 0 + async for event in self.ctx.subscribe_topology_events(after_seq=after_seq): + after_seq = event.seq + await self._event_queue.put(event) + + async def _refresh_snapshot(self) -> None: + self.snapshot = await self.ctx.snapshot() + self.last_snapshot_time = time.time() + + async def update(self) -> int: + """ + Drain queued topology events and refresh snapshot if needed. + + Returns: + Number of drained events. + """ + drained = 0 + refresh_requested = False + + while True: + try: + event = self._event_queue.get_nowait() + except asyncio.QueueEmpty: + break + self._events.append(event) + refresh_requested = True + drained += 1 + + if self.last_snapshot_time is None: + await self._refresh_snapshot() + elif refresh_requested or (time.time() - self.last_snapshot_time) >= self.snapshot_interval: + await self._refresh_snapshot() + + return drained + + def render(self, drained_events: int) -> None: + print("\x1bc", end="") + print("ezmsg topology tui") + print("Ctrl-C to quit") + print( + f"snapshot_interval={self.snapshot_interval:.2f}s " + f"render_interval={self.render_interval:.2f}s" + ) + if self.last_snapshot_time is not None: + print(f"snapshot_age={_fmt_age(max(0.0, time.time() - self.last_snapshot_time))}") + if drained_events > 0: + print(f"applied_events={drained_events}") + + snapshot = self.snapshot + if snapshot is None: + print("\n") + return + + edges = _flatten_edges(snapshot) + node_names = set(snapshot.graph.keys()) + for _, dst in edges: + node_names.add(dst) + + print( + "\nsummary: " + f"nodes={len(node_names)} edges={len(edges)} " + f"sessions={len(snapshot.sessions)} processes={len(snapshot.processes)}" + ) + + print("\nprocesses") + proc_header = 
f"{'Process':<30} {'PID':>8} {'Host':<24} {'Units':<80}" + print(proc_header) + print("-" * len(proc_header)) + if not snapshot.processes: + print("") + else: + process_items = sorted(snapshot.processes.values(), key=lambda p: p.process_id) + for proc in process_items[: self.max_processes]: + units = ", ".join(proc.units) if proc.units else "-" + print( + f"{_truncate(proc.process_id, 30):<30} " + f"{str(proc.pid) if proc.pid is not None else '-':>8} " + f"{_truncate(proc.host if proc.host is not None else '-', 24):<24} " + f"{_truncate(units, 80):<80}" + ) + if len(process_items) > self.max_processes: + print(f"... {len(process_items) - self.max_processes} more process rows") + + print("\nedges") + edge_header = f"{'From':<48} {'To':<48}" + print(edge_header) + print("-" * len(edge_header)) + if not edges: + print("") + else: + for src, dst in edges[: self.max_edges]: + print(f"{_truncate(src, 48):<48} {_truncate(dst, 48):<48}") + if len(edges) > self.max_edges: + print(f"... {len(edges) - self.max_edges} more edges") + + print("\nrecent topology events") + event_header = ( + f"{'Seq':>6} {'Type':<15} {'Age':>8} {'Topics':<44} " + f"{'Source Session':<38} {'Source Process':<30}" + ) + print(event_header) + print("-" * len(event_header)) + if not self._events: + print("") + else: + now = time.time() + for event in reversed(self._events): + topics = ", ".join(event.changed_topics) if event.changed_topics else "-" + print( + f"{event.seq:>6} " + f"{_truncate(event.event_type.value, 15):<15} " + f"{_fmt_age(max(0.0, now - event.timestamp)):>8} " + f"{_truncate(topics, 44):<44} " + f"{_truncate(event.source_session_id or '-', 38):<38} " + f"{_truncate(event.source_process_id or '-', 30):<30}" + ) + + +def _parse_address(host: str, port: int) -> tuple[str, int]: + return (host, port) + + +async def _run_tui(args: argparse.Namespace) -> None: + async with GraphContext( + _parse_address(args.host, args.port), auto_start=args.auto_start + ) as ctx: + tui = TopologyTUI( + 
ctx, + snapshot_interval=args.snapshot_interval, + render_interval=args.render_interval, + max_edges=args.max_edges, + max_events=args.max_events, + max_processes=args.max_processes, + ) + await tui.start() + try: + while True: + drained = await tui.update() + tui.render(drained_events=drained) + await asyncio.sleep(tui.render_interval) + finally: + await tui.close() + + +def _build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="ezmsg topology TUI") + parser.add_argument("--host", default=DEFAULT_HOST, help="GraphServer host") + parser.add_argument( + "--port", + type=int, + default=GRAPHSERVER_PORT_DEFAULT, + help="GraphServer port", + ) + parser.add_argument( + "--auto-start", + action="store_true", + help="Allow GraphContext to auto-start GraphServer if unavailable", + ) + parser.add_argument( + "--snapshot-interval", + type=float, + default=1.0, + help="Seconds between forced snapshot refreshes", + ) + parser.add_argument( + "--render-interval", + type=float, + default=0.5, + help="Seconds between screen redraws", + ) + parser.add_argument( + "--max-edges", + type=int, + default=50, + help="Max edge rows to render", + ) + parser.add_argument( + "--max-events", + type=int, + default=25, + help="Max recent topology events to retain/render", + ) + parser.add_argument( + "--max-processes", + type=int, + default=20, + help="Max process rows to render", + ) + return parser + + +def main() -> None: + parser = _build_parser() + args = parser.parse_args() + asyncio.run(_run_tui(args)) + + +if __name__ == "__main__": + main() diff --git a/src/ezmsg/core/backend.py b/src/ezmsg/core/backend.py index 4318f1e..7a06fbd 100644 --- a/src/ezmsg/core/backend.py +++ b/src/ezmsg/core/backend.py @@ -1,7 +1,6 @@ import asyncio from collections.abc import Callable, Mapping, Iterable from collections.abc import Collection as AbstractCollection -from dataclasses import asdict, is_dataclass import enum import inspect import logging @@ -53,6 +52,11 @@ 
UnitMetadata, ) from .relay import _CollectionRelayUnit, _RelaySettings +from .settingsmeta import ( + settings_repr_value, + settings_schema_from_type, + settings_schema_from_value, +) from .graphserver import GraphService from .graphcontext import GraphContext @@ -424,15 +428,7 @@ def _stream_type_name(self, stream_type: object) -> str: return repr(stream_type) def _settings_repr(self, value: object) -> dict[str, object] | str: - if is_dataclass(value): - try: - asdict_value = asdict(value) - if isinstance(asdict_value, dict): - return asdict_value - except Exception: - pass - - return repr(value) + return settings_repr_value(value) def _settings_snapshot(self, value: object) -> tuple[bytes | None, dict[str, object] | str]: try: @@ -579,6 +575,11 @@ def _component_metadata(self) -> GraphMetadata: if inspect.isclass(settings_type) else repr(settings_type) ) + settings_schema = ( + settings_schema_from_value(comp.SETTINGS) + if comp.SETTINGS is not None + else settings_schema_from_type(settings_type) + ) component_common = dict( address=comp.address, @@ -587,6 +588,7 @@ def _component_metadata(self) -> GraphMetadata: settings_type=settings_type_name, initial_settings=self._settings_snapshot(comp.SETTINGS), dynamic_settings=dynamic_settings, + settings_schema=settings_schema, ) metadata_entry: ComponentMetadataType diff --git a/src/ezmsg/core/backendprocess.py b/src/ezmsg/core/backendprocess.py index 8ef22a7..5421425 100644 --- a/src/ezmsg/core/backendprocess.py +++ b/src/ezmsg/core/backendprocess.py @@ -3,15 +3,16 @@ import logging import inspect import os -import time +import pickle import traceback import threading import weakref +from copy import deepcopy from abc import abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, fields as dataclass_fields, is_dataclass, replace from collections import defaultdict -from collections.abc import Callable, Coroutine, Generator, Sequence +from collections.abc import Awaitable, Callable, 
Coroutine, Generator, Mapping, Sequence from functools import wraps, partial from concurrent.futures import ThreadPoolExecutor from concurrent.futures.thread import _worker @@ -26,9 +27,24 @@ from .unit import Unit, TIMEIT_ATTR, SUBSCRIBES_ATTR from .graphcontext import GraphContext +from .graphmeta import ( + ProcessControlErrorCode, + ProcessControlOperation, + ProcessControlRequest, + ProcessControlResponse, + SettingsFieldUpdateRequest, + SettingsSnapshotValue, +) +from .profiling import PROFILES, PROFILE_TIME +from .processclient import ProcessControlClient from .pubclient import Publisher from .subclient import Subscriber from .netprotocol import AddressType +from .settingsmeta import ( + settings_repr_value, + settings_schema_from_value, + settings_structured_value, +) logger = logging.getLogger("ezmsg") @@ -220,12 +236,189 @@ class DefaultBackendProcess(BackendProcess): pubs: dict[str, Publisher] _shutdown_errors: bool + def _settings_snapshot_value(self, value: object) -> SettingsSnapshotValue: + try: + serialized = pickle.dumps(value) + except Exception: + serialized = None + + return SettingsSnapshotValue( + serialized=serialized, + repr_value=settings_repr_value(value), + structured_value=settings_structured_value(value), + settings_schema=settings_schema_from_value(value), + ) + + def _replace_settings_field( + self, settings_value: object, field_path: str, value: object + ) -> object: + if field_path == "": + raise ValueError("field_path must not be empty") + path = field_path.split(".") + + def apply(current: object, idx: int) -> object: + field_name = path[idx] + if isinstance(current, Mapping): + if field_name not in current: + raise AttributeError( + f"Settings field '{field_name}' does not exist in mapping" + ) + if idx == len(path) - 1: + updated = dict(current) + updated[field_name] = value + return updated + patched_child = apply(current[field_name], idx + 1) + updated = dict(current) + updated[field_name] = patched_child + return updated + + 
if not hasattr(current, field_name): + raise AttributeError( + f"Settings field '{field_name}' does not exist on " + f"{type(current).__name__}" + ) + + if idx == len(path) - 1: + return self._patch_object_field(current, field_name, value) + + child_value = getattr(current, field_name) + patched_child = apply(child_value, idx + 1) + return self._patch_object_field(current, field_name, patched_child) + + return apply(settings_value, 0) + + def _patch_object_field( + self, obj: object, field_name: str, value: object + ) -> object: + if is_dataclass(obj): + valid_fields = {f.name for f in dataclass_fields(obj)} + if field_name not in valid_fields: + raise AttributeError( + f"Settings field '{field_name}' does not exist on " + f"{type(obj).__name__}" + ) + return replace(obj, **{field_name: value}) + + if hasattr(obj, "model_copy") and callable(getattr(obj, "model_copy")): + return obj.model_copy(update={field_name: value}) # type: ignore[attr-defined] + + if hasattr(obj, "copy") and callable(getattr(obj, "copy")): + try: + return obj.copy(update={field_name: value}) # type: ignore[attr-defined] + except Exception: + pass + + if hasattr(obj, field_name): + patched = deepcopy(obj) + setattr(patched, field_name, value) + return patched + + raise TypeError(f"Cannot patch settings object of type {type(obj).__name__}") + def process(self, loop: asyncio.AbstractEventLoop) -> None: main_func = None context = GraphContext(self.graph_address) + process_client = ProcessControlClient(self.graph_address) + process_register_future: concurrent.futures.Future[None] | None = None coro_callables: dict[str, Callable[[], Coroutine[Any, Any, None]]] = dict() + settings_input_topics: dict[str, str] = {} + current_settings: dict[str, object] = {} + control_publishers: dict[str, Publisher] = {} self._shutdown_errors = False + async def process_request_handler( + request: ProcessControlRequest, + ) -> ProcessControlResponse: + if request.operation != 
ProcessControlOperation.UPDATE_SETTING_FIELD.value: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Unsupported process control operation: {request.operation}", + error_code=ProcessControlErrorCode.UNSUPPORTED_OPERATION, + process_id=process_client.process_id, + ) + + if request.payload is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="Missing settings field update payload", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=process_client.process_id, + ) + + try: + update_obj = pickle.loads(request.payload) + if not isinstance(update_obj, SettingsFieldUpdateRequest): + raise RuntimeError( + "settings field update payload was not SettingsFieldUpdateRequest" + ) + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Invalid settings field update payload: {exc}", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=process_client.process_id, + ) + + unit_address = request.unit_address + input_topic = settings_input_topics.get(unit_address) + if input_topic is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + f"Unit '{unit_address}' does not expose INPUT_SETTINGS; " + "settings field update unsupported" + ), + error_code=ProcessControlErrorCode.UNSUPPORTED_OPERATION, + process_id=process_client.process_id, + ) + + if unit_address not in current_settings: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + f"No current settings value tracked for unit '{unit_address}'. " + "Send a full settings object first via update_settings()." 
+ ), + error_code=ProcessControlErrorCode.HANDLER_ERROR, + process_id=process_client.process_id, + ) + + try: + patched = self._replace_settings_field( + current_settings[unit_address], + update_obj.field_path, + update_obj.value, + ) + control_pub = control_publishers.get(input_topic) + if control_pub is None: + control_pub = await context.publisher(input_topic) + control_publishers[input_topic] = control_pub + await control_pub.broadcast(patched) + current_settings[unit_address] = patched + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Failed to patch settings field: {exc}", + error_code=ProcessControlErrorCode.HANDLER_ERROR, + process_id=process_client.process_id, + ) + + result_value = self._settings_snapshot_value(patched) + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps(result_value), + process_id=process_client.process_id, + ) + + process_client.set_request_handler(process_request_handler) + try: self.pubs = dict() @@ -259,6 +452,8 @@ async def setup_state(): main_func = None for unit in self.units: + if unit.SETTINGS is not None: + current_settings[unit.address] = unit.SETTINGS sub_callables: defaultdict[ str, set[Callable[..., Coroutine[Any, Any, None]]] ] = defaultdict(set) @@ -284,8 +479,32 @@ async def setup_state(): loop, ).result() task_name = f"SUBSCRIBER|{stream.address}" + report_settings_update: ( + Callable[[object], Awaitable[None]] | None + ) = None + if stream.name == "INPUT_SETTINGS": + component_address = unit.address + settings_input_topics[component_address] = stream.address + + async def report_settings_update_cb( + msg: object, + *, + _component_address: str = component_address, + ) -> None: + current_settings[_component_address] = msg + value = self._settings_snapshot_value(msg) + await process_client.report_settings_update( + component_address=_component_address, + value=value, + ) + + report_settings_update = 
report_settings_update_cb + coro_callables[task_name] = partial( - handle_subscriber, sub, sub_callables[stream.address] + handle_subscriber, + sub, + sub_callables[stream.address], + on_message=report_settings_update, ) elif isinstance(stream, OutputStream): @@ -315,6 +534,17 @@ async def setup_state(): logger.debug("Waiting at start barrier!") self.start_barrier.wait() + async def register_process_control() -> None: + try: + await process_client.register([unit.address for unit in self.units]) + except Exception as exc: + logger.warning(f"Process control registration failed: {exc}") + + process_register_future = asyncio.run_coroutine_threadsafe( + register_process_control(), + loop, + ) + for unit in self.units: for thread_fn in unit.threads.values(): loop.run_in_executor(None, thread_fn, unit) @@ -407,6 +637,15 @@ async def shutdown_units() -> None: except TimeoutError: logger.warning("Timed out waiting for retry on context revert") + process_close_future = asyncio.run_coroutine_threadsafe( + process_client.close(), + loop=loop, + ) + with suppress(Exception): + if process_register_future is not None: + process_register_future.result(timeout=0.5) + process_close_future.result() + logger.debug(f"Remaining tasks in event loop = {asyncio.all_tasks(loop)}") if self.task_finished_ev is not None: @@ -439,11 +678,12 @@ async def publish(stream: Stream, obj: Any) -> None: await asyncio.sleep(0) async def perf_publish(stream: Stream, obj: Any) -> None: - start = time.perf_counter() + start = PROFILE_TIME() await publish(stream, obj) - stop = time.perf_counter() + stop = PROFILE_TIME() logger.info( - f"{task_address} send duration = " + f"{(stop - start) * 1e3:0.4f}ms" + f"{task_address} send duration = " + f"{((stop - start) / 1_000_000.0):0.4f}ms" ) pub_fn = perf_publish if hasattr(task, TIMEIT_ATTR) else publish @@ -487,8 +727,10 @@ async def wrapped_task(msg: Any = None) -> None: except Exception: logger.error(f"Exception in Task: {task_address}") 
logger.error(traceback.format_exc()) - if self.term_ev.is_set(): - self._shutdown_errors = True + # Any task exception should mark shutdown as unclean so + # interrupt-driven teardown can return a non-zero exit code. + # Gating this on term_ev introduces timing-dependent behavior. + self._shutdown_errors = True if strict_shutdown: raise @@ -496,7 +738,9 @@ async def wrapped_task(msg: Any = None) -> None: async def handle_subscriber( - sub: Subscriber, callables: set[Callable[..., Coroutine[Any, Any, None]]] + sub: Subscriber, + callables: set[Callable[..., Coroutine[Any, Any, None]]], + on_message: Callable[[Any], Awaitable[None]] | None = None, ): """ Handle incoming messages from a subscriber and distribute to callables. @@ -524,9 +768,22 @@ async def handle_subscriber( if sub.leaky: msg = await sub.recv() try: + if on_message is not None: + try: + await on_message(msg) + except Exception as exc: + logger.warning( + f"Failed to report subscriber message metadata: {exc}" + ) for callable in list(callables): try: - await callable(msg) + span_start_ns = sub.begin_profile() + try: + await callable(msg) + finally: + sub.end_profile( + span_start_ns, getattr(callable, "__name__", None) + ) except (Complete, NormalTermination): callables.remove(callable) finally: @@ -534,9 +791,22 @@ async def handle_subscriber( else: async with sub.recv_zero_copy() as msg: try: + if on_message is not None: + try: + await on_message(msg) + except Exception as exc: + logger.warning( + f"Failed to report subscriber message metadata: {exc}" + ) for callable in list(callables): try: - await callable(msg) + span_start_ns = sub.begin_profile() + try: + await callable(msg) + finally: + sub.end_profile( + span_start_ns, getattr(callable, "__name__", None) + ) except (Complete, NormalTermination): callables.remove(callable) finally: diff --git a/src/ezmsg/core/graphcontext.py b/src/ezmsg/core/graphcontext.py index 3beddfd..8197038 100644 --- a/src/ezmsg/core/graphcontext.py +++ 
b/src/ezmsg/core/graphcontext.py @@ -22,14 +22,30 @@ from .graphserver import GraphServer, GraphService from .pubclient import Publisher from .subclient import Subscriber -from .graphmeta import GraphMetadata, GraphSnapshot +from .graphmeta import ( + ProcessControlOperation, + GraphMetadata, + GraphSnapshot, + ProcessPing, + ProcessProfilingSnapshot, + ProcessProfilingTraceBatch, + ProfilingTraceStreamBatch, + ProfilingStreamControl, + ProcessStats, + ProcessControlResponse, + ProfilingTraceControl, + SettingsFieldUpdateRequest, + SettingsChangedEvent, + SettingsSnapshotValue, + TopologyChangedEvent, +) logger = logging.getLogger("ezmsg") class _SessionResponseKind(enum.Enum): BYTE = enum.auto() - SNAPSHOT = enum.auto() + PICKLED = enum.auto() @dataclass @@ -43,25 +59,35 @@ class _SessionCommand: class GraphContext: """ - GraphContext maintains a list of created publishers, subscribers, and connections in the graph. - - The GraphContext provides a managed environment for creating and tracking publishers, - subscribers, and graph connections. When the context is no longer needed, it can - revert changes in the graph which disconnects publishers and removes modifications - that this context made. - - It also maintains a context manager that ensures the GraphServer is running. - - :param graph_service: Optional graph service instance to use - :type graph_service: GraphService | None + Session-scoped client for graph mutation, metadata, settings, and process control. + + `GraphContext` opens a session connection to `GraphServer` and acts as a control + plane for both low-level graph operations and high-level API introspection. + + Core capabilities: + - Create/track `Publisher` and `Subscriber` clients. + - Connect/disconnect topic edges owned by this session. + - Register high-level `GraphMetadata`. + - Read graph snapshots (topology, edge ownership, sessions, process ownership). + - Query settings snapshots/events and subscribe to push-based settings updates. 
+ - Route process-control requests (ping/stats/profiling and custom operations). + - Revert all session-owned mutations on context exit (`SESSION_CLEAR`). + + Session semantics: + - Mutations and metadata are tied to the session lifecycle. + - If the session disconnects, session-owned graph state is dropped by server cleanup. + - Low-level pub/sub API usage remains supported independently of metadata. + + :param graph_address: Graph server address. If `None`, defaults are used. + :type graph_address: AddressType | None :param auto_start: Whether to auto-start a GraphServer if connection fails. If None, defaults to auto-start only when graph_address is not provided and no environment override is set. :type auto_start: bool | None .. note:: - The GraphContext is typically managed automatically by the ezmsg runtime - and doesn't need to be instantiated directly by user code. + `GraphContext` is used by the runtime, and can also be used directly by tools + (inspectors, profilers, dashboards, and operational scripts). 
""" _clients: set[Publisher | Subscriber] @@ -248,13 +274,13 @@ async def _session_io_loop(self) -> None: if cmd.response_kind == _SessionResponseKind.BYTE: response = await reader.read(1) - elif cmd.response_kind == _SessionResponseKind.SNAPSHOT: + elif cmd.response_kind == _SessionResponseKind.PICKLED: num_bytes = await read_int(reader) - snapshot_bytes = await reader.readexactly(num_bytes) + payload_bytes = await reader.readexactly(num_bytes) complete = await reader.read(1) if complete != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session snapshot") - response = pickle.loads(snapshot_bytes) + raise RuntimeError("Unexpected pickled response from session") + response = pickle.loads(payload_bytes) else: raise RuntimeError(f"Unsupported response kind: {cmd.response_kind}") @@ -334,18 +360,379 @@ async def register_metadata(self, metadata: GraphMetadata) -> None: payload=payload, response_kind=_SessionResponseKind.BYTE, ) - if response != Command.COMPLETE.value: - raise RuntimeError("Unexpected response to session metadata registration") + if response == Command.COMPLETE.value: + return + if response == Command.ERROR.value: + requested = set(metadata.components.keys()) + collisions: set[str] = set() + if len(requested) > 0: + own_session_id = str(self._session_id) if self._session_id is not None else None + try: + snapshot = await self.snapshot() + for session_id, session in snapshot.sessions.items(): + if own_session_id is not None and session_id == own_session_id: + continue + if session.metadata is None: + continue + collisions.update( + requested.intersection(session.metadata.components.keys()) + ) + except Exception: + # Fall back to a generic error if snapshot lookup fails. 
+ pass + + if len(collisions) > 0: + collision_str = ", ".join(sorted(collisions)) + raise RuntimeError( + "Session metadata registration rejected by GraphServer due to " + f"component address collision(s): {collision_str}" + ) + raise RuntimeError("Session metadata registration rejected by GraphServer") + raise RuntimeError( + "Unexpected response to session metadata registration: " + f"{response!r}" + ) async def snapshot(self) -> GraphSnapshot: snapshot = await self._session_command( Command.SESSION_SNAPSHOT, - response_kind=_SessionResponseKind.SNAPSHOT, + response_kind=_SessionResponseKind.PICKLED, ) if not isinstance(snapshot, GraphSnapshot): raise RuntimeError("Session snapshot payload was not a GraphSnapshot") return snapshot + async def settings_snapshot(self) -> dict[str, SettingsSnapshotValue]: + snapshot = await self._session_command( + Command.SESSION_SETTINGS_SNAPSHOT, + response_kind=_SessionResponseKind.PICKLED, + ) + if not isinstance(snapshot, dict): + raise RuntimeError("Settings snapshot payload was not a dictionary") + if not all(isinstance(value, SettingsSnapshotValue) for value in snapshot.values()): + raise RuntimeError("Settings snapshot payload contained invalid values") + return snapshot + + async def settings_events(self, after_seq: int = 0) -> list[SettingsChangedEvent]: + events = await self._session_command( + Command.SESSION_SETTINGS_EVENTS, + str(after_seq), + response_kind=_SessionResponseKind.PICKLED, + ) + if not isinstance(events, list): + raise RuntimeError("Settings event payload was not a list") + if not all(isinstance(event, SettingsChangedEvent) for event in events): + raise RuntimeError("Settings event payload contained invalid entries") + return events + + async def settings_input_topic(self, component_address: str) -> str: + """ + Resolve the dynamic settings input topic for a component. + + The topic is discovered from currently registered session metadata. 
+ Raises if the component is missing, does not opt in to dynamic settings, + or appears with conflicting dynamic settings topics. + """ + snapshot = await self.snapshot() + topics: set[str] = set() + for session in snapshot.sessions.values(): + metadata = session.metadata + if metadata is None: + continue + component = metadata.components.get(component_address) + if component is None: + continue + dynamic_settings = component.dynamic_settings + if dynamic_settings.enabled and dynamic_settings.input_topic is not None: + topics.add(dynamic_settings.input_topic) + + if len(topics) == 1: + return next(iter(topics)) + if len(topics) > 1: + raise RuntimeError( + "Conflicting dynamic settings topics for component " + f"'{component_address}': {sorted(topics)}" + ) + raise RuntimeError( + f"Component '{component_address}' does not expose dynamic settings metadata" + ) + + async def update_settings( + self, + component_address: str, + value: object, + *, + input_topic: str | None = None, + ) -> None: + """ + Publish a settings value to a component's `INPUT_SETTINGS` inlet. + + By default the target topic is resolved from metadata via + :meth:`settings_input_topic`. Supplying `input_topic` bypasses + metadata lookup. + """ + topic = input_topic if input_topic is not None else await self.settings_input_topic( + component_address + ) + pub = await self.publisher(topic) + try: + await pub.broadcast(value) + finally: + pub.close() + await pub.wait_closed() + self._clients.discard(pub) + + async def update_setting( + self, + component_address: str, + field_path: str, + value: object, + *, + timeout: float = 2.0, + ) -> SettingsSnapshotValue: + """ + Patch one field of a unit's current dynamic settings value. + + The patch is routed to the owning backend process, applied in-process + using dataclass replacement, and then published to `INPUT_SETTINGS`. + Returns a snapshot representation of the patched settings value. 
+ """ + response = await self.process_request( + component_address, + ProcessControlOperation.UPDATE_SETTING_FIELD, + payload_obj=SettingsFieldUpdateRequest(field_path=field_path, value=value), + timeout=timeout, + ) + return typing.cast( + SettingsSnapshotValue, + self.decode_process_payload(response, SettingsSnapshotValue), + ) + + async def subscribe_settings_events( + self, + *, + after_seq: int = 0, + ) -> typing.AsyncIterator[SettingsChangedEvent]: + async for event in self._subscribe_pickled_stream( + command=Command.SESSION_SETTINGS_SUBSCRIBE, + setup_payload=encode_str(str(after_seq)), + expected_type=SettingsChangedEvent, + subscribe_error="Failed to subscribe to settings events", + payload_error="Settings subscription received invalid event payload", + ): + yield typing.cast(SettingsChangedEvent, event) + + async def subscribe_topology_events( + self, + *, + after_seq: int = 0, + ) -> typing.AsyncIterator[TopologyChangedEvent]: + async for event in self._subscribe_pickled_stream( + command=Command.SESSION_TOPOLOGY_SUBSCRIBE, + setup_payload=encode_str(str(after_seq)), + expected_type=TopologyChangedEvent, + subscribe_error="Failed to subscribe to topology events", + payload_error="Topology subscription received invalid event payload", + ): + yield typing.cast(TopologyChangedEvent, event) + + async def subscribe_profiling_trace( + self, + control: ProfilingStreamControl, + ) -> typing.AsyncIterator[ProfilingTraceStreamBatch]: + """ + Subscribe to streamed profiling trace batches from GraphServer. 
+ """ + payload = pickle.dumps(control) + setup_payload = uint64_to_bytes(len(payload)) + payload + async for batch in self._subscribe_pickled_stream( + command=Command.SESSION_PROFILING_SUBSCRIBE, + setup_payload=setup_payload, + expected_type=ProfilingTraceStreamBatch, + subscribe_error="Failed to subscribe to profiling trace stream", + payload_error="Profiling subscription received invalid batch payload", + ): + yield typing.cast(ProfilingTraceStreamBatch, batch) + + async def _subscribe_pickled_stream( + self, + *, + command: Command, + setup_payload: bytes, + expected_type: type[object], + subscribe_error: str, + payload_error: str, + ) -> typing.AsyncIterator[object]: + reader, writer = await GraphService(self.graph_address).open_connection() + writer.write(command.value) + writer.write(setup_payload) + await writer.drain() + + _subscriber_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError(subscribe_error) + + try: + while True: + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + value = pickle.loads(payload) + if not isinstance(value, expected_type): + raise RuntimeError(payload_error) + yield value + except asyncio.IncompleteReadError: + return + finally: + await close_stream_writer(writer) + + async def process_request( + self, + unit_address: str, + operation: ProcessControlOperation | str, + *, + payload: bytes | None = None, + payload_obj: object | None = None, + timeout: float = 2.0, + ) -> ProcessControlResponse: + if payload is not None and payload_obj is not None: + raise ValueError("Specify only one of payload or payload_obj") + + if payload_obj is not None: + payload = pickle.dumps(payload_obj) + + operation_name = ( + operation.value if isinstance(operation, ProcessControlOperation) else operation + ) + response = await self._session_command( + Command.SESSION_PROCESS_REQUEST, + unit_address, 
+ operation_name, + str(timeout), + payload=payload if payload is not None else b"", + response_kind=_SessionResponseKind.PICKLED, + ) + if not isinstance(response, ProcessControlResponse): + raise RuntimeError("Session process request payload was not ProcessControlResponse") + return response + + async def process_ping( + self, + unit_address: str, + *, + timeout: float = 2.0, + ) -> ProcessPing: + response = await self.process_request( + unit_address, + ProcessControlOperation.PING, + timeout=timeout, + ) + return typing.cast(ProcessPing, self.decode_process_payload(response, ProcessPing)) + + async def process_stats( + self, + unit_address: str, + *, + timeout: float = 2.0, + ) -> ProcessStats: + response = await self.process_request( + unit_address, + ProcessControlOperation.GET_PROCESS_STATS, + timeout=timeout, + ) + return typing.cast( + ProcessStats, self.decode_process_payload(response, ProcessStats) + ) + + async def process_profiling_snapshot( + self, + unit_address: str, + *, + timeout: float = 2.0, + ) -> ProcessProfilingSnapshot: + response = await self.process_request( + unit_address, + ProcessControlOperation.GET_PROFILING_SNAPSHOT, + timeout=timeout, + ) + return typing.cast( + ProcessProfilingSnapshot, + self.decode_process_payload(response, ProcessProfilingSnapshot), + ) + + async def process_set_profiling_trace( + self, + unit_address: str, + control: ProfilingTraceControl, + *, + timeout: float = 2.0, + ) -> ProcessControlResponse: + return await self.process_request( + unit_address, + ProcessControlOperation.SET_PROFILING_TRACE, + payload_obj=control, + timeout=timeout, + ) + + async def process_profiling_trace_batch( + self, + unit_address: str, + *, + max_samples: int = 1000, + timeout: float = 2.0, + ) -> ProcessProfilingTraceBatch: + response = await self.process_request( + unit_address, + ProcessControlOperation.GET_PROFILING_TRACE_BATCH, + payload_obj=max_samples, + timeout=timeout, + ) + return typing.cast( + ProcessProfilingTraceBatch, 
+ self.decode_process_payload(response, ProcessProfilingTraceBatch), + ) + + async def profiling_snapshot_all( + self, + *, + timeout_per_process: float = 0.5, + ) -> dict[UUID, ProcessProfilingSnapshot]: + graph_snapshot = await self.snapshot() + out: dict[UUID, ProcessProfilingSnapshot] = {} + for process in graph_snapshot.processes.values(): + if len(process.units) == 0: + continue + route_unit = process.units[0] + try: + out[process.process_id] = await self.process_profiling_snapshot( + route_unit, timeout=timeout_per_process + ) + except Exception: + continue + return out + + def decode_process_payload( + self, + response: ProcessControlResponse, + expected_type: type[object] = object, + ) -> object: + if not response.ok: + raise RuntimeError( + f"Process request failed ({response.error_code}): {response.error}" + ) + if response.payload is None: + raise RuntimeError("Process response did not include a payload") + decoded = pickle.loads(response.payload) + if expected_type is object: + return decoded + if not isinstance(decoded, expected_type): + raise RuntimeError( + "Unexpected process payload type: " + f"{type(decoded).__name__} (expected {expected_type.__name__})" + ) + return decoded + async def _shutdown_servers(self) -> None: if self._graph_server is not None: self._graph_server.stop() diff --git a/src/ezmsg/core/graphmeta.py b/src/ezmsg/core/graphmeta.py index a9b439e..c96e4d7 100644 --- a/src/ezmsg/core/graphmeta.py +++ b/src/ezmsg/core/graphmeta.py @@ -1,5 +1,8 @@ +import enum + from dataclasses import dataclass, field from typing import Any, TypeAlias, NamedTuple +from uuid import UUID @dataclass @@ -81,6 +84,25 @@ class TaskMetadata: publishes: list[str] = field(default_factory=list) +@dataclass +class SettingsFieldMetadata: + name: str + field_type: str + required: bool + default: Any + description: str | None + bounds: tuple[float | None, float | None] | None + choices: list[Any] | None + widget_hint: str | None + + +@dataclass +class 
SettingsSchemaMetadata: + provider: str + settings_type: str + fields: list[SettingsFieldMetadata] + + SettingsReprType: TypeAlias = dict[str, Any] | str SerializedSettingsType: TypeAlias = bytes | None InitialSettingsType: TypeAlias = tuple[SerializedSettingsType, SettingsReprType] @@ -94,6 +116,7 @@ class ComponentMetadata: settings_type: str initial_settings: InitialSettingsType dynamic_settings: DynamicSettingsMetadata + settings_schema: SettingsSchemaMetadata | None @dataclass @@ -123,6 +146,223 @@ class GraphMetadata: components: dict[str, ComponentMetadataType] +@dataclass +class ProcessRegistration: + pid: int + host: str + units: list[str] + + +@dataclass +class ProcessOwnershipUpdate: + added_units: list[str] = field(default_factory=list) + removed_units: list[str] = field(default_factory=list) + + +@dataclass +class SettingsSnapshotValue: + serialized: bytes | None + repr_value: dict[str, Any] | str + structured_value: dict[str, Any] | None = None + settings_schema: SettingsSchemaMetadata | None = None + + +class SettingsEventType(enum.Enum): + INITIAL_SETTINGS = "INITIAL_SETTINGS" + SETTINGS_UPDATED = "SETTINGS_UPDATED" + + +@dataclass +class SettingsChangedEvent: + seq: int + event_type: SettingsEventType + component_address: str + timestamp: float + source_session_id: str | None + source_process_id: UUID | None + value: SettingsSnapshotValue + + +class TopologyEventType(enum.Enum): + GRAPH_CHANGED = "GRAPH_CHANGED" + PROCESS_CHANGED = "PROCESS_CHANGED" + + +@dataclass +class TopologyChangedEvent: + seq: int + event_type: TopologyEventType + timestamp: float + changed_topics: list[str] + source_session_id: str | None + source_process_id: UUID | None + + +@dataclass +class ProcessSettingsUpdate: + component_address: str + value: SettingsSnapshotValue + timestamp: float + + +@dataclass +class ProcessControlRequest: + request_id: str + unit_address: str + operation: "ProcessControlOperation | str" + payload: bytes | None = None + + +class 
ProcessControlOperation(enum.Enum): + PING = "PING" + GET_PROCESS_STATS = "GET_PROCESS_STATS" + GET_PROFILING_SNAPSHOT = "GET_PROFILING_SNAPSHOT" + SET_PROFILING_TRACE = "SET_PROFILING_TRACE" + GET_PROFILING_TRACE_BATCH = "GET_PROFILING_TRACE_BATCH" + UPDATE_SETTING_FIELD = "UPDATE_SETTING_FIELD" + + +class ProcessControlErrorCode(enum.Enum): + UNROUTABLE_UNIT = "UNROUTABLE_UNIT" + ROUTE_WRITE_FAILED = "ROUTE_WRITE_FAILED" + TIMEOUT = "TIMEOUT" + PROCESS_DISCONNECTED = "PROCESS_DISCONNECTED" + UNSUPPORTED_OPERATION = "UNSUPPORTED_OPERATION" + HANDLER_NOT_CONFIGURED = "HANDLER_NOT_CONFIGURED" + HANDLER_ERROR = "HANDLER_ERROR" + INVALID_RESPONSE = "INVALID_RESPONSE" + + +@dataclass +class ProcessControlResponse: + request_id: str + ok: bool + payload: bytes | None = None + error: str | None = None + error_code: ProcessControlErrorCode | None = None + process_id: UUID | None = None + + +@dataclass +class SettingsFieldUpdateRequest: + field_path: str + value: Any + + +@dataclass +class ProcessPing: + process_id: UUID + pid: int + host: str + timestamp: float + + +@dataclass +class ProcessStats: + process_id: UUID + pid: int + host: str + owned_units: list[str] + timestamp: float + + +class ProfileChannelType(enum.Enum): + LOCAL = "LOCAL" + SHM = "SHM" + TCP = "TCP" + UNKNOWN = "UNKNOWN" + + +@dataclass +class PublisherProfileSnapshot: + endpoint_id: str + topic: str + messages_published_total: int + messages_published_window: int + publish_delta_ns_avg_window: float + publish_rate_hz_window: float + inflight_messages_current: int + num_buffers: int + inflight_messages_peak_window: int + backpressure_wait_ns_total: int + backpressure_wait_ns_window: int + timestamp: float + + +@dataclass +class SubscriberProfileSnapshot: + endpoint_id: str + topic: str + messages_received_total: int + messages_received_window: int + lease_time_ns_total: int + lease_time_ns_avg_window: float + user_span_ns_total: int + user_span_ns_avg_window: float + attributable_backpressure_ns_total: 
@dataclass
class ProfilingStreamControl:
    """Client-supplied options for a profiling trace subscription.

    Sent pickled when the subscription is opened; the graph server rejects
    any other payload type.
    """

    # Seconds between batch collections; the server clamps this to >= 0.01.
    interval: float = 0.05
    # Per-process cap on samples in one batch; the server clamps this to >= 1.
    max_samples: int = 1000
    # Restrict the stream to these processes; None means every process that
    # is connected at collection time.
    process_ids: list[UUID] | None = None
    # When False, a process with no new samples is omitted from the batch.
    # The server still sends nothing when the whole batch has no entries.
    include_empty_batches: bool = False
b/src/ezmsg/core/graphserver.py @@ -4,6 +4,9 @@ import os import socket import threading +import time +from collections import deque +from collections.abc import Sequence from contextlib import suppress from uuid import UUID, uuid1 @@ -13,14 +16,32 @@ from .graph_util import get_compactified_graph, graph_string, prune_graph_connections from .graphmeta import ( Edge, + ProcessControlOperation, + ProcessControlErrorCode, GraphMetadata, GraphSnapshot, + ProcessProfilingTraceBatch, + ProfilingTraceSample, + ProfilingTraceStreamBatch, + ProfilingStreamControl, + ProcessControlRequest, + ProcessControlResponse, + ProcessRegistration, + ProcessOwnershipUpdate, + ProcessSettingsUpdate, + SettingsChangedEvent, + SettingsEventType, + SettingsSnapshotValue, + TopologyChangedEvent, + TopologyEventType, + SnapshotProcess, SnapshotSession, ) from .netprotocol import ( Address, Command, ClientInfo, + ProcessInfo, SessionInfo, SubscriberInfo, PublisherInfo, @@ -73,6 +94,22 @@ class GraphServer(threading.Thread): _client_tasks: dict[UUID, "asyncio.Task[None]"] _command_lock: asyncio.Lock + _settings_current: dict[str, SettingsSnapshotValue] + _settings_source_session: dict[str, UUID | None] + _settings_source_process: dict[str, UUID | None] + _settings_events: list[SettingsChangedEvent] + _settings_event_seq: int + _settings_owned_by_session: dict[UUID, set[str]] + _settings_subscribers: dict[UUID, asyncio.Queue[object]] + _topology_events: list[TopologyChangedEvent] + _topology_event_seq: int + _topology_subscribers: dict[UUID, asyncio.Queue[object]] + _pending_process_requests: dict[ + str, tuple[UUID, "asyncio.Future[ProcessControlResponse]"] + ] + _profiling_trace_buffers: dict[UUID, deque[tuple[int, ProfilingTraceSample]]] + _profiling_trace_process_meta: dict[UUID, tuple[int, str]] + _profiling_trace_seq: dict[UUID, int] def __init__(self, **kwargs) -> None: super().__init__( @@ -88,6 +125,20 @@ def __init__(self, **kwargs) -> None: self._client_tasks = {} self.shms = {} 
self._address = None + self._settings_current = {} + self._settings_source_session = {} + self._settings_source_process = {} + self._settings_events = [] + self._settings_event_seq = 0 + self._settings_owned_by_session = {} + self._settings_subscribers = {} + self._topology_events = [] + self._topology_event_seq = 0 + self._topology_subscribers = {} + self._pending_process_requests = {} + self._profiling_trace_buffers = {} + self._profiling_trace_process_meta = {} + self._profiling_trace_seq = {} @property def address(self) -> Address: @@ -283,6 +334,71 @@ async def api( # to avoid closing writer return + elif req == Command.SESSION_SETTINGS_SUBSCRIBE.value: + subscriber_id = uuid1() + after_seq = int(await read_str(reader)) + writer.write(encode_str(str(subscriber_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[subscriber_id] = asyncio.create_task( + self._handle_settings_subscriber( + subscriber_id, after_seq, reader, writer + ) + ) + + # NOTE: Created a stream client, must return early + # to avoid closing writer + return + + elif req == Command.SESSION_TOPOLOGY_SUBSCRIBE.value: + subscriber_id = uuid1() + after_seq = int(await read_str(reader)) + writer.write(encode_str(str(subscriber_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[subscriber_id] = asyncio.create_task( + self._handle_topology_subscriber( + subscriber_id, after_seq, reader, writer + ) + ) + + # NOTE: Created a stream client, must return early + # to avoid closing writer + return + + elif req == Command.SESSION_PROFILING_SUBSCRIBE.value: + subscriber_id = uuid1() + stream_control = await self._read_profiling_stream_control(reader) + writer.write(encode_str(str(subscriber_id))) + writer.write(Command.COMPLETE.value) + await writer.drain() + self._client_tasks[subscriber_id] = asyncio.create_task( + self._handle_profiling_subscriber( + subscriber_id, + stream_control, + reader, + writer, + ) + ) + + # NOTE: Created a 
def _process_info(self, process_client_id: UUID) -> ProcessInfo | None:
    """Return the ``ProcessInfo`` stored under *process_client_id*.

    Gives ``None`` when the id is unknown or refers to a client record of
    some other kind.
    """
    candidate = self.clients.get(process_client_id)
    return candidate if isinstance(candidate, ProcessInfo) else None
process_client_id: UUID) -> UUID: + return process_client_id + + async def _handle_process( + self, + process_client_id: UUID, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + ) -> None: + logger.debug(f"Graph Server: Process control connected: {process_client_id}") + + try: + while True: + req = await reader.read(1) + + if not req: + break + + if req == Command.PROCESS_REGISTER.value: + response = await self._handle_process_register_request( + process_client_id, reader + ) + await self._write_process_response( + process_client_id, writer, response + ) + + elif req == Command.PROCESS_UPDATE_OWNERSHIP.value: + response = await self._handle_process_update_ownership_request( + process_client_id, reader + ) + await self._write_process_response( + process_client_id, writer, response + ) + + elif req == Command.PROCESS_SETTINGS_UPDATE.value: + response = await self._handle_process_settings_update_request( + process_client_id, reader + ) + await self._write_process_response( + process_client_id, writer, response + ) + + elif req == Command.PROCESS_PROFILING_TRACE_UPDATE.value: + await self._handle_process_profiling_trace_update_request( + process_client_id, reader + ) + + elif req == Command.PROCESS_ROUTE_RESPONSE.value: + await self._handle_process_route_response_request( + process_client_id, reader + ) + + else: + logger.warning( + f"Process control {process_client_id} rx unknown command: {req}" + ) + + except (ConnectionResetError, BrokenPipeError) as e: + logger.debug( + f"Process control {process_client_id} disconnected from GraphServer: {e}" + ) + + finally: + process_info = self._process_info(process_client_id) + + async with self._command_lock: + request_ids = [ + request_id + for request_id, (owner_process_id, _) in self._pending_process_requests.items() + if owner_process_id == process_client_id + ] + for request_id in request_ids: + pending = self._pending_process_requests.pop(request_id, None) + if pending is None: + continue + _, response_fut 
async def _read_pickled_payload(self, reader: asyncio.StreamReader) -> object:
    """Read one length-prefixed pickle frame from *reader* and unpickle it.

    The frame is a size read with ``read_int`` followed by exactly that many
    payload bytes.

    Raises:
        asyncio.IncompleteReadError: the stream ended before the full
            payload arrived (raised by ``readexactly``).
        Exception: whatever ``pickle.loads`` raises for a malformed payload.
    """
    # NOTE(review): the size prefix is not bounded here, so a peer can make
    # readexactly() buffer an arbitrarily large payload.
    payload_size = await read_int(reader)
    payload = await reader.readexactly(payload_size)
    # SECURITY: pickle.loads can execute arbitrary code from the byte stream.
    # This is only safe if every peer able to reach this socket is trusted.
    return pickle.loads(payload)
async def _handle_event_subscriber(
    self,
    *,
    subscriber_id: UUID,
    after_seq: int,
    reader: asyncio.StreamReader,
    writer: asyncio.StreamWriter,
    queue: asyncio.Queue[object],
    subscribers: dict[UUID, asyncio.Queue[object]],
    events: Sequence[object],
    label: str,
) -> None:
    """Serve one event-stream subscriber until its connection closes.

    Registers *queue* in *subscribers*, replays every retained event whose
    ``seq`` is greater than *after_seq*, then starts a sender task that
    drains the queue to *writer*. The method itself only watches *reader*
    for EOF; any bytes the client sends are read and discarded.

    On exit the subscriber is unregistered, the sender task is cancelled and
    awaited, and *writer* is closed. The close is guaranteed even when the
    sender task has already failed with an unexpected error.

    Args:
        subscriber_id: Key for this subscriber in *subscribers* and in
            ``self._client_tasks``.
        after_seq: Replay only events with a higher sequence number.
        reader: Client stream, used solely to detect disconnect.
        writer: Client stream the sender task writes events to.
        queue: Per-subscriber event queue.
        subscribers: Registry that event producers iterate to fan out events.
        events: Retained event history to replay from.
        label: Stream name used in log messages and the sender task name.
    """
    async with self._command_lock:
        subscribers[subscriber_id] = queue
        # Replay goes through _queue_stream_event, which drops the oldest
        # queued item when the queue is full, so a backlog longer than the
        # queue capacity keeps only its newest entries.
        for event in events:
            if getattr(event, "seq", 0) > after_seq:
                self._queue_stream_event(queue, event)

    sender_task = asyncio.create_task(
        self._stream_sender(subscriber_id, queue, writer, label),
        name=f"{label}-sender-{subscriber_id}",
    )

    try:
        while True:
            req = await reader.read(1)
            if not req:
                break
    except (ConnectionResetError, BrokenPipeError) as e:
        logger.debug(f"{label} subscriber {subscriber_id} disconnected: {e}")
    finally:
        try:
            async with self._command_lock:
                subscribers.pop(subscriber_id, None)
            self._client_tasks.pop(subscriber_id, None)
            sender_task.cancel()
            try:
                await sender_task
            except asyncio.CancelledError:
                pass
            except Exception as exc:
                # The sender only handles connection resets itself. Anything
                # else (e.g. a pickling error) surfaces here and must not
                # abort the cleanup below.
                logger.warning(
                    "%s subscriber %s sender failed: %s",
                    label,
                    subscriber_id,
                    exc,
                )
        finally:
            # Always release the socket, whatever happened above.
            await close_stream_writer(writer)
async def _collect_profiling_trace_stream_batch(
    self,
    *,
    stream_control: ProfilingStreamControl,
    last_seq_by_process: dict[UUID, int],
) -> ProfilingTraceStreamBatch:
    """Gather trace samples not yet delivered to one profiling subscriber.

    Args:
        stream_control: Subscriber options (process filter, per-process
            sample cap, whether to emit empty batches).
        last_seq_by_process: Per-process delivery cursor owned by the
            caller; updated in place for every process that has buffered
            samples.

    Returns:
        A stream batch stamped with the collection time, holding one
        per-process batch for each selected process that has new samples
        (or for every selected process when empty batches are requested).
    """
    wanted = stream_control.process_ids
    selected = set(wanted) if wanted is not None else None
    per_process_cap = max(1, int(stream_control.max_samples))
    stamp = time.time()
    collected: dict[UUID, ProcessProfilingTraceBatch] = {}

    async with self._command_lock:
        # pid/host of every connected process, with placeholders for
        # processes that have not reported them.
        live: dict[UUID, tuple[int, str]] = {
            self._process_key(client_id): (
                -1 if client.pid is None else client.pid,
                "" if client.host is None else client.host,
            )
            for client_id, client in self.clients.items()
            if isinstance(client, ProcessInfo)
        }

        targets: list[UUID] = sorted(
            live if selected is None else selected, key=str
        )

        for target in targets:
            ring = self._profiling_trace_buffers.get(target)
            fresh: list[ProfilingTraceSample] = []
            if ring:
                # If the buffer has rolled past the cursor, resume from the
                # oldest sample still retained.
                cursor = max(last_seq_by_process.get(target, 0), ring[0][0] - 1)
                for seq, sample in ring:
                    if seq > cursor:
                        fresh.append(sample)
                        cursor = seq
                        if len(fresh) >= per_process_cap:
                            break
                last_seq_by_process[target] = cursor

            if not fresh and not stream_control.include_empty_batches:
                continue

            # Prefer live connection metadata; otherwise use what the
            # process last reported with its trace data.
            remembered = self._profiling_trace_process_meta.get(target, (-1, ""))
            pid, host = live.get(target, remembered)
            collected[target] = ProcessProfilingTraceBatch(
                process_id=target,
                pid=pid,
                host=host,
                timestamp=stamp,
                samples=fresh,
            )

    return ProfilingTraceStreamBatch(timestamp=stamp, batches=collected)
async def _handle_process_register_request(
    self, process_client_id: UUID, reader: asyncio.StreamReader
) -> bytes:
    """Handle a process registration and return the status bytes to send.

    Reads a ``ProcessRegistration`` from *reader*. The registration is
    refused with ``Command.ERROR`` when the payload is invalid, the
    connection has no ``ProcessInfo`` record, or any requested unit is
    already owned by another process. Otherwise pid, host and the unit set
    are stored, a ``PROCESS_CHANGED`` topology event is appended if the unit
    set changed, and ``Command.COMPLETE`` is returned.
    """
    registration = await self._read_typed_payload(
        reader,
        ProcessRegistration,
        log_prefix=f"Process control {process_client_id} registration",
    )
    if registration is None:
        return Command.ERROR.value

    async with self._command_lock:
        record = self._process_info(process_client_id)
        if record is None:
            return Command.ERROR.value

        requested_units = set(registration.units)

        # Units that some other process has already claimed.
        contested: list[str] = []
        for unit in requested_units:
            holder = self._process_owner_for_unit(unit)
            if holder is not None and holder != process_client_id:
                contested.append(unit)
        contested.sort()

        if contested:
            logger.warning(
                "Process control %s register rejected due to unit ownership conflict(s): %s",
                process_client_id,
                ", ".join(contested),
            )
            return Command.ERROR.value

        units_before = set(record.units)
        record.pid = registration.pid
        record.host = registration.host
        record.units = requested_units
        if record.units != units_before:
            self._append_topology_event_locked(
                event_type=TopologyEventType.PROCESS_CHANGED,
                changed_topics=[],
                source_session_id=None,
                source_process_id=self._process_key(process_client_id),
            )

    return Command.COMPLETE.value
async def _handle_process_settings_update_request(
    self, process_client_id: UUID, reader: asyncio.StreamReader
) -> bytes:
    """Apply a settings update reported by a process; return status bytes.

    Reads a ``ProcessSettingsUpdate`` from *reader*. The update is accepted
    when the component is one of the process's registered units, or when the
    process has registered no units yet and some session's metadata declares
    the component (a startup race between registration and the first
    settings report). Accepted updates replace the current settings value,
    record the session and process they came from, and append a
    ``SETTINGS_UPDATED`` event.

    Returns:
        ``Command.COMPLETE`` when the update was applied, ``Command.ERROR``
        when the payload was invalid, the connection has no ``ProcessInfo``
        record, or the component is not owned by this process.
    """
    update = await self._read_typed_payload(
        reader,
        ProcessSettingsUpdate,
        log_prefix=f"Process control {process_client_id} settings update",
    )
    if update is None:
        return Command.ERROR.value

    async with self._command_lock:
        process_info = self._process_info(process_client_id)
        if process_info is None:
            return Command.ERROR.value

        component_address = update.component_address
        metadata_owner = self._session_owner_for_component_locked(
            component_address
        )

        if component_address not in process_info.units:
            # The previous code also accepted the update when
            # _process_owner_for_unit() returned this process, but that
            # lookup can only return this process when the unit IS in
            # process_info.units, so the test could never pass here.
            # NOTE(review): the startup exemption does not check whether a
            # different process already owns the component -- confirm that
            # is intended.
            allow_startup_race = (
                not process_info.units and metadata_owner is not None
            )
            if not allow_startup_race:
                logger.warning(
                    "Process control %s settings update rejected for unowned component: %s",
                    process_client_id,
                    component_address,
                )
                return Command.ERROR.value

        # With no declaring session, keep whichever session was previously
        # recorded as the source for this component (possibly None).
        if metadata_owner is None:
            source_session_id = self._settings_source_session.get(
                component_address
            )
        else:
            source_session_id = metadata_owner

        source_process_id = self._process_key(process_client_id)
        self._settings_current[component_address] = update.value
        self._settings_source_session[component_address] = source_session_id
        self._settings_source_process[component_address] = source_process_id
        self._append_settings_event_locked(
            event_type=SettingsEventType.SETTINGS_UPDATED,
            component_address=component_address,
            value=update.value,
            source_session_id=(
                str(source_session_id) if source_session_id is not None else None
            ),
            source_process_id=source_process_id,
            timestamp=update.timestamp,
        )

    return Command.COMPLETE.value
async def _route_process_request(
    self,
    unit_address: str,
    operation: str,
    payload: bytes | None,
    timeout: float,
) -> ProcessControlResponse:
    """Forward a control request to the process owning *unit_address*.

    The request is written to the owning process's connection and the call
    then waits up to *timeout* seconds for the matching response to be
    delivered through ``self._pending_process_requests``.

    Args:
        unit_address: Unit whose owning process should handle the request.
        operation: Operation name passed through to the process.
        payload: Optional opaque request body.
        timeout: Seconds to wait for the response.

    Returns:
        The process's response, or a failure response built here with
        ``error_code`` set to ``UNROUTABLE_UNIT``, ``ROUTE_WRITE_FAILED``
        or ``TIMEOUT``.
    """
    request_id = str(uuid1())
    response_fut: asyncio.Future[ProcessControlResponse] = (
        asyncio.get_running_loop().create_future()
    )
    request = ProcessControlRequest(
        request_id=request_id,
        unit_address=unit_address,
        operation=operation,
        payload=payload,
    )

    async with self._command_lock:
        process_info = self._process_for_unit(unit_address)
        if process_info is None:
            return ProcessControlResponse(
                request_id=request_id,
                ok=False,
                error=f"No process owns unit '{unit_address}'",
                error_code=ProcessControlErrorCode.UNROUTABLE_UNIT,
            )

        self._pending_process_requests[request_id] = (process_info.id, response_fut)

    try:
        try:
            async with process_info.write_lock:
                process_writer = process_info.writer
                request_bytes = pickle.dumps(request)
                process_writer.write(Command.PROCESS_ROUTE_REQUEST.value)
                process_writer.write(uint64_to_bytes(len(request_bytes)))
                process_writer.write(request_bytes)
                await process_writer.drain()
        except Exception as exc:
            return ProcessControlResponse(
                request_id=request_id,
                ok=False,
                error=f"Failed to route request to owning process: {exc}",
                error_code=ProcessControlErrorCode.ROUTE_WRITE_FAILED,
                process_id=self._process_key(process_info.id),
            )

        try:
            return await asyncio.wait_for(response_fut, timeout=timeout)
        except asyncio.TimeoutError:
            return ProcessControlResponse(
                request_id=request_id,
                ok=False,
                error=(
                    "Timed out waiting for process response "
                    f"(unit={unit_address}, operation={operation}, timeout={timeout}s)"
                ),
                error_code=ProcessControlErrorCode.TIMEOUT,
                process_id=self._process_key(process_info.id),
            )
    finally:
        # One cleanup point for every exit path, including cancellation of
        # the caller, so an abandoned request never lingers in the table.
        # The pop is synchronous (no await), so it cannot interleave with
        # code that scans the table, and it is a no-op when the response
        # handler already removed the entry.
        self._pending_process_requests.pop(request_id, None)
def _session_owner_for_component_locked(self, component_address: str) -> UUID | None:
    """Return the id of the first session whose metadata declares the component.

    Clients that are not sessions, and sessions without registered metadata,
    are ignored. Returns ``None`` when no session declares
    *component_address*.
    """
    declaring_sessions = (
        client_id
        for client_id, client in self.clients.items()
        if isinstance(client, SessionInfo)
        and client.metadata is not None
        and component_address in client.metadata.components
    )
    return next(declaring_sessions, None)
def _remove_settings_for_process_locked(self, process_client_id: UUID) -> None:
    """Undo the settings contributions of a process.

    For every component whose current settings were last written by this
    process: if the component's source session still provides initial
    settings, those are restored and a ``SETTINGS_UPDATED`` event is
    appended; otherwise the component's settings bookkeeping is dropped.
    """
    departed = self._process_key(process_client_id)
    # Snapshot first: the loop below mutates the dict being scanned.
    affected = [
        address
        for address, writer_process in self._settings_source_process.items()
        if writer_process == departed
    ]

    for address in affected:
        session_id = self._settings_source_session.get(address)
        fallback = (
            None
            if session_id is None
            else self._initial_settings_for_component_locked(session_id, address)
        )

        if fallback is None:
            # Nothing to fall back to: forget the component entirely.
            for table in (
                self._settings_current,
                self._settings_source_session,
                self._settings_source_process,
            ):
                table.pop(address, None)
            continue

        self._settings_current[address] = fallback
        self._settings_source_process[address] = None
        self._append_settings_event_locked(
            event_type=SettingsEventType.SETTINGS_UPDATED,
            component_address=address,
            value=fallback,
            source_session_id=str(session_id),
            source_process_id=None,
        )
repr_value=initial_repr, + structured_value=initial_repr if isinstance(initial_repr, dict) else None, + settings_schema=component.settings_schema, + ) + self._settings_current[component.address] = value + self._settings_source_session[component.address] = session_id + self._settings_source_process[component.address] = None + session_components.add(component.address) + self._append_settings_event_locked( + event_type=SettingsEventType.INITIAL_SETTINGS, + component_address=component.address, + value=value, + source_session_id=str(session_id), + source_process_id=None, + ) + + self._settings_owned_by_session[session_id] = session_components + async def _handle_session_edge_request( self, session_id: UUID, @@ -529,6 +1533,12 @@ async def _handle_session_edge_request( return Command.CYCLIC.value if topology_changed: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=[to_topic], + source_session_id=str(session_id), + source_process_id=None, + ) await self._notify_downstream_for_topic(to_topic) return Command.COMPLETE.value @@ -543,24 +1553,26 @@ async def _handle_session_clear_request(self, session_id: UUID) -> bytes: async def _handle_session_register_request( self, session_id: UUID, reader: asyncio.StreamReader ) -> bytes: - num_bytes = await read_int(reader) - payload = await reader.readexactly(num_bytes) - metadata: GraphMetadata | None = None - try: - metadata_obj = pickle.loads(payload) - if isinstance(metadata_obj, GraphMetadata): - metadata = metadata_obj - else: - raise RuntimeError("metadata payload was not GraphMetadata") - except Exception as exc: - logger.warning( - f"Session {session_id} metadata parse failed; ignoring payload: {exc}" - ) + metadata = await self._read_typed_payload( + reader, + GraphMetadata, + log_prefix=f"Session {session_id} metadata", + ) async with self._command_lock: session = self._session_info(session_id) if session is not None and metadata is not None: + collisions = 
self._metadata_collisions(session_id, metadata) + if collisions: + logger.warning( + "Session %s metadata registration rejected due to component address collision(s): %s", + session_id, + ", ".join(collisions), + ) + return Command.ERROR.value + self._remove_settings_for_session_locked(session_id) session.metadata = metadata + self._apply_session_metadata_settings_locked(session_id, metadata) return Command.COMPLETE.value @@ -574,6 +1586,29 @@ async def _handle_session_snapshot_request( writer.write(snapshot_bytes) writer.write(Command.COMPLETE.value) + async def _handle_session_settings_snapshot_request( + self, writer: asyncio.StreamWriter + ) -> None: + async with self._command_lock: + snapshot = { + component_address: self._settings_current[component_address] + for component_address in sorted(self._settings_current) + } + snapshot_bytes = pickle.dumps(snapshot) + writer.write(uint64_to_bytes(len(snapshot_bytes))) + writer.write(snapshot_bytes) + writer.write(Command.COMPLETE.value) + + async def _handle_session_settings_events_request( + self, writer: asyncio.StreamWriter, after_seq: int + ) -> None: + async with self._command_lock: + events = [event for event in self._settings_events if event.seq > after_seq] + event_bytes = pickle.dumps(events) + writer.write(uint64_to_bytes(len(event_bytes))) + writer.write(event_bytes) + writer.write(Command.COMPLETE.value) + def _clear_session_state(self, session_id: UUID) -> set[str]: notify_topics: set[str] = set() session = self._session_info(session_id) @@ -584,7 +1619,15 @@ def _clear_session_state(self, session_id: UUID) -> set[str]: if self._disconnect_owner(from_topic, to_topic, session_id): notify_topics.add(to_topic) + self._remove_settings_for_session_locked(session_id) session.metadata = None + if notify_topics: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=list(notify_topics), + source_session_id=str(session_id), + source_process_id=None, + ) return 
notify_topics def _drop_session(self, session_id: UUID) -> set[str]: @@ -597,8 +1640,16 @@ def _drop_session(self, session_id: UUID) -> set[str]: if self._disconnect_owner(from_topic, to_topic, session_id): notify_topics.add(to_topic) + self._remove_settings_for_session_locked(session_id) session.metadata = None self.clients.pop(session_id, None) + if notify_topics: + self._append_topology_event_locked( + event_type=TopologyEventType.GRAPH_CHANGED, + changed_topics=list(notify_topics), + source_session_id=str(session_id), + source_process_id=None, + ) return notify_topics def _snapshot(self) -> GraphSnapshot: @@ -632,7 +1683,28 @@ def _snapshot(self) -> GraphSnapshot: key=lambda item: str(item[0]), ) } - return GraphSnapshot(graph=graph, edge_owners=edge_owners, sessions=sessions) + processes = { + client_id: SnapshotProcess( + process_id=self._process_key(client_id), + pid=process.pid, + host=process.host, + units=sorted(process.units), + ) + for client_id, process in sorted( + [ + (client_id, info) + for client_id, info in self.clients.items() + if isinstance(info, ProcessInfo) + ], + key=lambda item: str(item[0]), + ) + } + return GraphSnapshot( + graph=graph, + edge_owners=edge_owners, + sessions=sessions, + processes=processes, + ) async def _notify_downstream_for_topic(self, topic: str) -> None: for sub in self._downstream_subs(topic): diff --git a/src/ezmsg/core/messagechannel.py b/src/ezmsg/core/messagechannel.py index 8b50e29..3c86fa5 100644 --- a/src/ezmsg/core/messagechannel.py +++ b/src/ezmsg/core/messagechannel.py @@ -21,6 +21,8 @@ encode_str, close_stream_writer, ) +from .profiling import PROFILES, PROFILE_TIME +from .graphmeta import ProfileChannelType logger = logging.getLogger("ezmsg") @@ -99,6 +101,8 @@ class Channel: _pub_writer: asyncio.StreamWriter _graph_address: AddressType | None _local_backpressure: Backpressure | None + _channel_kind: ProfileChannelType + _lease_start: dict[tuple[UUID, int], int] def __init__( self, @@ -125,6 +129,8 @@ def 
__init__( self.clients = dict() self._graph_address = graph_address self._local_backpressure = None + self._channel_kind = ProfileChannelType.UNKNOWN + self._lease_start = {} @classmethod async def create( @@ -257,8 +263,10 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: msg_id = await read_int(reader) buf_idx = msg_id % self.num_buffers + channel_kind = ProfileChannelType.UNKNOWN if msg == Command.TX_SHM.value: + channel_kind = ProfileChannelType.SHM shm_name = await read_str(reader) if self.shm is not None and self.shm.name != shm_name: @@ -285,6 +293,7 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: self.cache.put_from_mem(self.shm[buf_idx]) elif msg == Command.TX_TCP.value: + channel_kind = ProfileChannelType.TCP buf_size = await read_int(reader) obj_bytes = await reader.readexactly(buf_size) assert MessageMarshal.msg_id(obj_bytes) == msg_id @@ -293,6 +302,8 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: else: raise ValueError(f"unimplemented data telemetry: {msg}") + self._set_channel_kind(channel_kind) + if not self._notify_clients(msg_id): # Nobody is listening; need to ack! 
self.cache.release(msg_id) @@ -310,13 +321,31 @@ async def _publisher_connection(self, reader: asyncio.StreamReader) -> None: logger.debug(f"disconnected: channel:{self.id} -> pub:{self.pub_id}") + def _set_channel_kind(self, kind: ProfileChannelType) -> None: + if self._channel_kind == ProfileChannelType.UNKNOWN: + self._channel_kind = kind + elif self._channel_kind != kind: + logger.warning( + "Channel %s observed channel kind change: %s -> %s", + self.id, + self._channel_kind.value, + kind.value, + ) + self._channel_kind = kind + + @property + def channel_kind(self) -> ProfileChannelType: + return self._channel_kind + def _notify_clients(self, msg_id: int) -> bool: """notify interested clients and return true if any were notified""" buf_idx = msg_id % self.num_buffers + now_ns = PROFILE_TIME() for client_id, queue in self.clients.items(): if queue is None: continue # queue is none if this is the pub self.backpressure.lease(client_id, buf_idx) + self._lease_start[(client_id, msg_id)] = now_ns queue.put_nowait((self.pub_id, msg_id)) return not self.backpressure.available(buf_idx) @@ -331,6 +360,7 @@ def put_local(self, msg_id: int, msg: typing.Any) -> None: ) buf_idx = msg_id % self.num_buffers + self._set_channel_kind(ProfileChannelType.LOCAL) if self._notify_clients(msg_id): self.cache.put_local(msg, msg_id) self._local_backpressure.lease(self.id, buf_idx) @@ -379,6 +409,18 @@ def _release_backpressure(self, msg_id: int, client_id: UUID) -> None: :param client_id: UUID of client releasing this message :type client_id: UUID """ + now_ns = PROFILE_TIME() + lease = self._lease_start.pop((client_id, msg_id), None) + if lease is not None: + start_ns = lease + PROFILES.subscriber_attributed_backpressure( + client_id, + now_ns, + now_ns - start_ns, + self._channel_kind, + msg_seq=msg_id, + ) + buf_idx = msg_id % self.num_buffers self.backpressure.free(client_id, buf_idx) if self.backpressure.buffers[buf_idx].is_empty: @@ -434,6 +476,11 @@ def unregister_client(self, 
client_id: UUID) -> None: queue.put_nowait((pub_id, msg_id)) self.backpressure.free(client_id) + stale = [ + key for key in self._lease_start.keys() if key[0] == client_id + ] + for key in stale: + self._lease_start.pop(key, None) elif client_id == self.pub_id and self._local_backpressure is not None: self._local_backpressure.free(self.id) diff --git a/src/ezmsg/core/netprotocol.py b/src/ezmsg/core/netprotocol.py index dca5858..ff6bf24 100644 --- a/src/ezmsg/core/netprotocol.py +++ b/src/ezmsg/core/netprotocol.py @@ -176,6 +176,18 @@ class SessionInfo(ClientInfo): metadata: GraphMetadata | None = None +@dataclass +class ProcessInfo(ClientInfo): + """ + Process-scoped control-plane client information. + """ + + pid: int | None = None + host: str | None = None + units: set[str] = field(default_factory=set) + write_lock: asyncio.Lock = field(default_factory=asyncio.Lock, init=False) + + def uint64_to_bytes(i: int) -> bytes: """ Convert a 64-bit unsigned integer to bytes. @@ -318,6 +330,22 @@ def _generate_next_value_(name, start, count, last_values) -> bytes: SESSION_CLEAR = enum.auto() SESSION_REGISTER = enum.auto() SESSION_SNAPSHOT = enum.auto() + SESSION_SETTINGS_SNAPSHOT = enum.auto() + SESSION_SETTINGS_EVENTS = enum.auto() + SESSION_SETTINGS_SUBSCRIBE = enum.auto() + SESSION_TOPOLOGY_SUBSCRIBE = enum.auto() + SESSION_PROFILING_SUBSCRIBE = enum.auto() + SESSION_PROCESS_REQUEST = enum.auto() + + # Backend Process Control Commands + PROCESS = enum.auto() + PROCESS_REGISTER = enum.auto() + PROCESS_UPDATE_OWNERSHIP = enum.auto() + PROCESS_SETTINGS_UPDATE = enum.auto() + PROCESS_PROFILING_TRACE_UPDATE = enum.auto() + PROCESS_ROUTE_REQUEST = enum.auto() + PROCESS_ROUTE_RESPONSE = enum.auto() + ERROR = enum.auto() def create_socket( diff --git a/src/ezmsg/core/processclient.py b/src/ezmsg/core/processclient.py new file mode 100644 index 0000000..97c12ce --- /dev/null +++ b/src/ezmsg/core/processclient.py @@ -0,0 +1,483 @@ +import asyncio +import logging +import os 
+import pickle +import socket +import time + +from uuid import UUID +from contextlib import suppress +from collections.abc import Awaitable, Callable + +from .graphmeta import ( + ProcessProfilingSnapshot, + ProcessProfilingTraceBatch, + ProcessControlErrorCode, + ProcessControlOperation, + ProcessControlRequest, + ProcessControlResponse, + ProfilingTraceControl, + ProcessPing, + ProcessRegistration, + ProcessStats, + ProcessOwnershipUpdate, + ProcessSettingsUpdate, + SettingsSnapshotValue, +) +from .profiling import PROFILES +from .graphserver import GraphService +from .netprotocol import ( + AddressType, + Command, + close_stream_writer, + read_int, + read_str, + uint64_to_bytes, +) + +logger = logging.getLogger("ezmsg") + + +class ProcessControlClient: + _graph_address: AddressType | None + _client_id: UUID | None + _reader: asyncio.StreamReader | None + _writer: asyncio.StreamWriter | None + _write_lock: asyncio.Lock + _ack_queue: asyncio.Queue[bytes] + _io_task: asyncio.Task[None] | None + _request_handler: Callable[ + [ProcessControlRequest], ProcessControlResponse | Awaitable[ProcessControlResponse] + ] | None + _owned_units: set[str] + _trace_push_task: asyncio.Task[None] | None + _trace_push_interval_s: float + _trace_push_max_samples: int + + def __init__(self, graph_address: AddressType | None = None) -> None: + self._graph_address = graph_address + self._client_id = None + self._reader = None + self._writer = None + self._write_lock = asyncio.Lock() + self._ack_queue = asyncio.Queue() + self._io_task = None + self._request_handler = None + self._owned_units = set() + self._trace_push_task = None + self._trace_push_interval_s = float( + os.environ.get("EZMSG_PROFILE_TRACE_PUSH_INTERVAL_S", "0.02") + ) + self._trace_push_max_samples = int( + os.environ.get("EZMSG_PROFILE_TRACE_PUSH_MAX_SAMPLES", "5000") + ) + + def _require_client_id(self) -> UUID: + if self._client_id is None: + raise RuntimeError("Process control connection is not active") + return 
self._client_id + + @property + def process_id(self) -> UUID: + return self._require_client_id() + + @property + def client_id(self) -> UUID | None: + return self._client_id + + async def connect(self) -> None: + if self._writer is not None: + return + + reader, writer = await GraphService(self._graph_address).open_connection() + writer.write(Command.PROCESS.value) + await writer.drain() + + client_id = UUID(await read_str(reader)) + response = await reader.read(1) + if response != Command.COMPLETE.value: + await close_stream_writer(writer) + raise RuntimeError("Failed to create process control connection") + + self._client_id = client_id + PROFILES.set_process_id(client_id) + self._reader = reader + self._writer = writer + self._io_task = asyncio.create_task( + self._io_loop(), + name=f"process-control-{client_id}", + ) + + def set_request_handler( + self, + handler: Callable[ + [ProcessControlRequest], ProcessControlResponse | Awaitable[ProcessControlResponse] + ] + | None, + ) -> None: + self._request_handler = handler + + async def register(self, units: list[str]) -> None: + await self.connect() + normalized_units = sorted(set(units)) + payload = ProcessRegistration( + pid=os.getpid(), + host=socket.gethostname(), + units=normalized_units, + ) + await self._payload_command(Command.PROCESS_REGISTER, payload) + self._owned_units = set(normalized_units) + + async def update_ownership( + self, + added_units: list[str] | None = None, + removed_units: list[str] | None = None, + ) -> None: + await self.connect() + added = sorted(set(added_units or [])) + removed = sorted(set(removed_units or [])) + payload = ProcessOwnershipUpdate( + added_units=added, + removed_units=removed, + ) + await self._payload_command(Command.PROCESS_UPDATE_OWNERSHIP, payload) + self._owned_units.update(added) + self._owned_units.difference_update(removed) + + async def report_settings_update( + self, + component_address: str, + value: SettingsSnapshotValue, + timestamp: float | None = None, 
+ ) -> None: + await self.connect() + payload = ProcessSettingsUpdate( + component_address=component_address, + value=value, + timestamp=timestamp if timestamp is not None else time.time(), + ) + await self._payload_command(Command.PROCESS_SETTINGS_UPDATE, payload) + + async def close(self) -> None: + writer = self._writer + if writer is None: + return + + trace_task = self._trace_push_task + self._trace_push_task = None + if trace_task is not None: + trace_task.cancel() + with suppress(asyncio.CancelledError): + await trace_task + + io_task = self._io_task + self._io_task = None + if io_task is not None: + io_task.cancel() + with suppress(asyncio.CancelledError): + await io_task + + self._reader = None + self._writer = None + self._client_id = None + await close_stream_writer(writer) + + async def _payload_command(self, command: Command, payload_obj: object) -> None: + await self._write_payload(command, payload_obj, expect_complete=True) + + async def _write_payload( + self, + command: Command, + payload_obj: object, + *, + expect_complete: bool, + ) -> None: + reader = self._reader + writer = self._writer + if reader is None or writer is None: + raise RuntimeError("Process control connection is not active") + + payload = pickle.dumps(payload_obj) + async with self._write_lock: + writer.write(command.value) + writer.write(uint64_to_bytes(len(payload))) + writer.write(payload) + await writer.drain() + + if not expect_complete: + return + + try: + response = await asyncio.wait_for(self._ack_queue.get(), timeout=5.0) + except asyncio.TimeoutError as exc: + raise RuntimeError( + f"Timed out waiting for response to process control command: {command.name}" + ) from exc + + if response != Command.COMPLETE.value: + if response == Command.ERROR.value: + raise RuntimeError( + f"Process control command failed: {command.name}" + ) + raise RuntimeError( + f"Unexpected response to process control command: {command.name}" + ) + + async def _io_loop(self) -> None: + reader = 
self._reader + writer = self._writer + if reader is None or writer is None: + return + + try: + while True: + req = await reader.read(1) + if not req: + break + + if req in (Command.COMPLETE.value, Command.ERROR.value): + self._ack_queue.put_nowait(req) + continue + + if req != Command.PROCESS_ROUTE_REQUEST.value: + logger.warning( + "Process control %s received unknown command: %s", + self._client_id, + req, + ) + continue + + payload_size = await read_int(reader) + payload = await reader.readexactly(payload_size) + request: ProcessControlRequest | None = None + try: + request_obj = pickle.loads(payload) + if isinstance(request_obj, ProcessControlRequest): + request = request_obj + else: + raise RuntimeError( + "process route request payload was not ProcessControlRequest" + ) + except Exception as exc: + logger.warning( + "Process control %s failed to parse route request: %s", + self._client_id, + exc, + ) + + if request is None: + continue + + response = await self._handle_route_request(request) + await self._write_payload( + Command.PROCESS_ROUTE_RESPONSE, + response, + expect_complete=False, + ) + + except asyncio.CancelledError: + raise + except (ConnectionResetError, BrokenPipeError) as exc: + logger.debug(f"Process control {self._client_id} disconnected: {exc}") + + async def _handle_route_request( + self, request: ProcessControlRequest + ) -> ProcessControlResponse: + operation: ProcessControlOperation | None = None + if isinstance(request.operation, ProcessControlOperation): + operation = request.operation + elif isinstance(request.operation, str): + with suppress(ValueError): + operation = ProcessControlOperation(request.operation) + + if operation == ProcessControlOperation.PING: + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps( + ProcessPing( + process_id=self.process_id, + pid=os.getpid(), + host=socket.gethostname(), + timestamp=time.time(), + ) + ), + process_id=self.process_id, + ) + + if operation 
== ProcessControlOperation.GET_PROCESS_STATS: + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps( + ProcessStats( + process_id=self.process_id, + pid=os.getpid(), + host=socket.gethostname(), + owned_units=sorted(self._owned_units), + timestamp=time.time(), + ) + ), + process_id=self.process_id, + ) + + if operation == ProcessControlOperation.GET_PROFILING_SNAPSHOT: + snapshot: ProcessProfilingSnapshot = PROFILES.snapshot() + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps(snapshot), + process_id=self.process_id, + ) + + if operation == ProcessControlOperation.SET_PROFILING_TRACE: + control: ProfilingTraceControl | None = None + try: + if request.payload is not None: + control_obj = pickle.loads(request.payload) + if isinstance(control_obj, ProfilingTraceControl): + control = control_obj + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Invalid profiling trace control payload: {exc}", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=self.process_id, + ) + + if control is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="Missing profiling trace control payload", + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=self.process_id, + ) + + PROFILES.set_trace_control(control) + if control.enabled: + await self._ensure_trace_push_task() + else: + await self._cancel_trace_push_task() + + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + process_id=self.process_id, + ) + + if operation == ProcessControlOperation.GET_PROFILING_TRACE_BATCH: + max_samples = 1000 + if request.payload is not None: + try: + max_samples_obj = pickle.loads(request.payload) + if isinstance(max_samples_obj, int): + max_samples = max(1, max_samples_obj) + except Exception: + pass + + batch: ProcessProfilingTraceBatch = 
PROFILES.trace_batch( + max_samples=max_samples + ) + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps(batch), + process_id=self.process_id, + ) + + if self._request_handler is None: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"Unsupported process control operation: {request.operation}", + error_code=ProcessControlErrorCode.HANDLER_NOT_CONFIGURED, + process_id=self.process_id, + ) + + try: + result = self._request_handler(request) + if asyncio.iscoroutine(result): + result = await result + except Exception as exc: + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=f"process request handler failed: {exc}", + error_code=ProcessControlErrorCode.HANDLER_ERROR, + process_id=self.process_id, + ) + + if not isinstance(result, ProcessControlResponse): + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + "process request handler returned invalid response type: " + f"{type(result).__name__}" + ), + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=self.process_id, + ) + + if result.request_id != request.request_id: + result = ProcessControlResponse( + request_id=request.request_id, + ok=False, + error=( + "process request handler returned mismatched request_id: " + f"{result.request_id}" + ), + error_code=ProcessControlErrorCode.INVALID_RESPONSE, + process_id=self.process_id, + ) + + if result.process_id is None: + result.process_id = self.process_id + + return result + + async def _ensure_trace_push_task(self) -> None: + task = self._trace_push_task + if task is not None and not task.done(): + return + self._trace_push_task = asyncio.create_task( + self._trace_push_loop(), + name=f"proc-trace-push-{self.process_id}", + ) + + async def _cancel_trace_push_task(self) -> None: + task = self._trace_push_task + self._trace_push_task = None + if task is None: + return + task.cancel() + with 
suppress(asyncio.CancelledError): + await task + + async def _trace_push_loop(self) -> None: + try: + while True: + await asyncio.sleep(max(0.01, self._trace_push_interval_s)) + batch: ProcessProfilingTraceBatch = PROFILES.trace_batch( + max_samples=max(1, self._trace_push_max_samples) + ) + if len(batch.samples) > 0: + await self._write_payload( + Command.PROCESS_PROFILING_TRACE_UPDATE, + batch, + expect_complete=False, + ) + + if not PROFILES.trace_enabled(): + break + except asyncio.CancelledError: + raise + except (ConnectionResetError, BrokenPipeError): + logger.debug("Process trace push loop disconnected") + except Exception as exc: + logger.warning(f"Process trace push loop failed: {exc}") + finally: + if asyncio.current_task() is self._trace_push_task: + self._trace_push_task = None diff --git a/src/ezmsg/core/profiling.py b/src/ezmsg/core/profiling.py new file mode 100644 index 0000000..e38ce89 --- /dev/null +++ b/src/ezmsg/core/profiling.py @@ -0,0 +1,545 @@ +import os +import socket +import time +import heapq +from collections import deque +from dataclasses import dataclass, field +from typing import Callable, TypeAlias +from uuid import UUID + +from .graphmeta import ( + ProcessProfilingSnapshot, + ProcessProfilingTraceBatch, + ProfileChannelType, + ProfilingTraceControl, + ProfilingTraceSample, + PublisherProfileSnapshot, + SubscriberProfileSnapshot, +) + + +WINDOW_SECONDS = float(os.environ.get("EZMSG_PROFILE_WINDOW_SECONDS", "10.0")) +BUCKET_SECONDS = float(os.environ.get("EZMSG_PROFILE_BUCKET_SECONDS", "0.1")) +TRACE_MAX_SAMPLES = int(os.environ.get("EZMSG_PROFILE_TRACE_MAX_SAMPLES", "10000")) +# Must return monotonic nanoseconds so *_ns metrics remain unit-consistent. 
+PROFILE_TIME_TYPE: TypeAlias = Callable[[], int] +PROFILE_TIME: PROFILE_TIME_TYPE = time.perf_counter_ns + + +def _endpoint_id(topic: str, id: UUID) -> str: + return f"{topic}:{id}" + + +@dataclass +class _Rolling: + window_seconds: float = WINDOW_SECONDS + bucket_seconds: float = BUCKET_SECONDS + count: list[int] = field(default_factory=list) + value_sum: list[int] = field(default_factory=list) + max_value: list[int] = field(default_factory=list) + _num_buckets: int = 0 + _bucket_ns: int = 0 + _last_bucket: int | None = None + + def __post_init__(self) -> None: + self._num_buckets = max(1, int(self.window_seconds / self.bucket_seconds)) + self._bucket_ns = max(1, int(self.bucket_seconds * 1e9)) + self.count = [0 for _ in range(self._num_buckets)] + self.value_sum = [0 for _ in range(self._num_buckets)] + self.max_value = [0 for _ in range(self._num_buckets)] + + def _bucket(self, ts_ns: int) -> int: + return (ts_ns // self._bucket_ns) % self._num_buckets + + def _advance(self, ts_ns: int) -> int: + bucket = self._bucket(ts_ns) + if self._last_bucket is None: + self._last_bucket = bucket + return bucket + if bucket == self._last_bucket: + return bucket + idx = (self._last_bucket + 1) % self._num_buckets + while idx != bucket: + self.count[idx] = 0 + self.value_sum[idx] = 0 + self.max_value[idx] = 0 + idx = (idx + 1) % self._num_buckets + self.count[bucket] = 0 + self.value_sum[bucket] = 0 + self.max_value[bucket] = 0 + self._last_bucket = bucket + return bucket + + def add(self, ts_ns: int, value: int) -> None: + idx = self._advance(ts_ns) + self.count[idx] += 1 + self.value_sum[idx] += value + if value > self.max_value[idx]: + self.max_value[idx] = value + + def count_total(self) -> int: + return sum(self.count) + + def sum_total(self) -> int: + return sum(self.value_sum) + + def max_total(self) -> int: + return max(self.max_value) if self.max_value else 0 + + def avg(self) -> float: + c = self.count_total() + if c == 0: + return 0.0 + return 
float(self.sum_total()) / float(c) + + +@dataclass +class _PublisherMetrics: + topic: str + endpoint_id: str + num_buffers: int + messages_published_total: int = 0 + backpressure_wait_ns_total: int = 0 + inflight_messages_current: int = 0 + _last_publish_ts_ns: int | None = None + _publish_delta: _Rolling = field(default_factory=_Rolling) + _publish_count: _Rolling = field(default_factory=lambda: _Rolling()) + _backpressure_wait: _Rolling = field(default_factory=_Rolling) + _inflight: _Rolling = field(default_factory=_Rolling) + trace_enabled: bool = False + trace_sample_mod: int = 1 + trace_metrics: set[str] | None = None + _trace_counter: int = 0 + trace_samples: deque[ProfilingTraceSample] = field( + default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES) + ) + + def _trace_metric_enabled(self, metric: str) -> bool: + return self.trace_metrics is None or metric in self.trace_metrics + + def record_publish( + self, ts_ns: int, inflight: int, msg_seq: int | None = None + ) -> None: + self.messages_published_total += 1 + self._publish_count.add(ts_ns, 1) + publish_delta_ns = 0 + if self._last_publish_ts_ns is not None: + publish_delta_ns = ts_ns - self._last_publish_ts_ns + self._publish_delta.add(ts_ns, publish_delta_ns) + self._last_publish_ts_ns = ts_ns + self.sample_inflight(ts_ns, inflight) + self._trace_counter += 1 + if ( + self.trace_enabled + and self._trace_metric_enabled("publish_delta_ns") + and (self._trace_counter % max(1, self.trace_sample_mod) == 0) + ): + self.trace_samples.append( + ProfilingTraceSample( + timestamp=float(PROFILE_TIME()), + endpoint_id=self.endpoint_id, + topic=self.topic, + metric="publish_delta_ns", + value=float(publish_delta_ns), + sample_seq=msg_seq, + ) + ) + + def record_backpressure_wait( + self, ts_ns: int, wait_ns: int, msg_seq: int | None = None + ) -> None: + self.backpressure_wait_ns_total += wait_ns + self._backpressure_wait.add(ts_ns, wait_ns) + if self.trace_enabled and 
@dataclass
class _SubscriberMetrics:
    """Profiling counters, rolling windows and trace buffer for one subscriber.

    The ``record_*`` methods update lifetime totals and the matching rolling
    window, and append a ``ProfilingTraceSample`` to ``trace_samples`` when
    tracing is enabled for the metric in question. ``snapshot`` packages the
    current state into a ``SubscriberProfileSnapshot``.
    """

    topic: str
    endpoint_id: str
    messages_received_total: int = 0
    lease_time_ns_total: int = 0
    user_span_ns_total: int = 0
    attributable_backpressure_ns_total: int = 0
    attributable_backpressure_events_total: int = 0
    channel_kind_last: ProfileChannelType = ProfileChannelType.UNKNOWN
    _recv_count: _Rolling = field(default_factory=_Rolling)
    _lease_time: _Rolling = field(default_factory=_Rolling)
    _user_span: _Rolling = field(default_factory=_Rolling)
    _attrib_bp: _Rolling = field(default_factory=_Rolling)
    trace_enabled: bool = False
    trace_sample_mod: int = 1
    # None means "no filter": every metric is traced.
    trace_metrics: set[str] | None = None
    _trace_counter: int = 0
    trace_samples: deque[ProfilingTraceSample] = field(
        default_factory=lambda: deque(maxlen=TRACE_MAX_SAMPLES)
    )

    def _trace_metric_enabled(self, metric: str) -> bool:
        """Return True when ``metric`` passes the trace-metric filter."""
        allowed = self.trace_metrics
        if allowed is None:
            return True
        return metric in allowed

    def _append_trace_sample(
        self,
        metric: str,
        value: int,
        channel_kind: ProfileChannelType,
        msg_seq: int | None,
        topic: str | None = None,
    ) -> None:
        """Append one trace sample stamped with the current profile time.

        ``topic`` overrides the subscriber topic when given.
        """
        self.trace_samples.append(
            ProfilingTraceSample(
                timestamp=float(PROFILE_TIME()),
                endpoint_id=self.endpoint_id,
                topic=self.topic if topic is None else topic,
                metric=metric,
                value=float(value),
                channel_kind=channel_kind,
                sample_seq=msg_seq,
            )
        )

    def record_receive(
        self,
        ts_ns: int,
        lease_ns: int,
        channel_kind: ProfileChannelType,
        msg_seq: int | None = None,
    ) -> None:
        """Account for one received message and the time its lease was held.

        The trace counter advances on every call; a ``lease_time_ns`` sample
        is only emitted when tracing is on, the metric is selected, and the
        counter is a multiple of ``trace_sample_mod`` (clamped to at least 1).
        """
        self.messages_received_total += 1
        self.lease_time_ns_total += lease_ns
        self.channel_kind_last = channel_kind
        self._recv_count.add(ts_ns, 1)
        self._lease_time.add(ts_ns, lease_ns)
        self._trace_counter += 1

        if not self.trace_enabled:
            return
        if not self._trace_metric_enabled("lease_time_ns"):
            return
        if self._trace_counter % max(1, self.trace_sample_mod) != 0:
            return
        self._append_trace_sample("lease_time_ns", lease_ns, channel_kind, msg_seq)

    def record_user_span(
        self, ts_ns: int, span_ns: int, label: str | None, msg_seq: int | None = None
    ) -> None:
        """Account for a user-measured span, optionally tagged with ``label``.

        When a label is given, the emitted trace sample's topic becomes
        ``"<topic>:<label>"``. The sample carries the most recently seen
        channel kind.
        """
        self.user_span_ns_total += span_ns
        self._user_span.add(ts_ns, span_ns)

        if not (self.trace_enabled and self._trace_metric_enabled("user_span_ns")):
            return
        sample_topic = self.topic if label is None else f"{self.topic}:{label}"
        self._append_trace_sample(
            "user_span_ns",
            span_ns,
            self.channel_kind_last,
            msg_seq,
            topic=sample_topic,
        )

    def record_attributed_backpressure(
        self,
        ts_ns: int,
        duration_ns: int,
        channel_kind: ProfileChannelType,
        msg_seq: int | None = None,
    ) -> None:
        """Account for backpressure time attributed to this subscriber."""
        self.attributable_backpressure_ns_total += duration_ns
        self.attributable_backpressure_events_total += 1
        self.channel_kind_last = channel_kind
        self._attrib_bp.add(ts_ns, duration_ns)

        if not self.trace_enabled:
            return
        if not self._trace_metric_enabled("attributable_backpressure_ns"):
            return
        self._append_trace_sample(
            "attributable_backpressure_ns", duration_ns, channel_kind, msg_seq
        )

    def snapshot(self) -> SubscriberProfileSnapshot:
        """Build a ``SubscriberProfileSnapshot`` from totals and window state."""
        received_in_window = self._recv_count.count_total()
        span_events = self._user_span.count_total()
        lease_avg = self._lease_time.avg()
        # No spans in the window: report 0.0 instead of dividing by zero.
        if span_events > 0:
            span_avg = float(self._user_span.sum_total()) / float(span_events)
        else:
            span_avg = 0.0
        backpressure_in_window = self._attrib_bp.sum_total()

        return SubscriberProfileSnapshot(
            endpoint_id=self.endpoint_id,
            topic=self.topic,
            messages_received_total=self.messages_received_total,
            messages_received_window=received_in_window,
            lease_time_ns_total=self.lease_time_ns_total,
            lease_time_ns_avg_window=lease_avg,
            user_span_ns_total=self.user_span_ns_total,
            user_span_ns_avg_window=span_avg,
            attributable_backpressure_ns_total=self.attributable_backpressure_ns_total,
            attributable_backpressure_ns_window=backpressure_in_window,
            attributable_backpressure_events_total=self.attributable_backpressure_events_total,
            channel_kind_last=self.channel_kind_last,
            timestamp=float(PROFILE_TIME()),
        )
self._publishers[pub_id] = metric + self._apply_trace_control_to_publisher(metric) + + def unregister_publisher(self, pub_id: UUID) -> None: + self._publishers.pop(pub_id, None) + + def register_subscriber(self, sub_id: UUID, topic: str) -> None: + metric = _SubscriberMetrics( + topic=topic, + endpoint_id=_endpoint_id(topic, sub_id), + ) + self._subscribers[sub_id] = metric + self._apply_trace_control_to_subscriber(metric) + + def unregister_subscriber(self, sub_id: UUID) -> None: + self._subscribers.pop(sub_id, None) + + def publisher_publish( + self, pub_id: UUID, ts_ns: int, inflight: int, msg_seq: int | None = None + ) -> None: + self._expire_trace_control_if_needed(ts_ns) + metric = self._publishers.get(pub_id) + if metric is not None: + metric.record_publish(ts_ns, inflight, msg_seq) + + def publisher_backpressure_wait( + self, pub_id: UUID, ts_ns: int, wait_ns: int, msg_seq: int | None = None + ) -> None: + self._expire_trace_control_if_needed(ts_ns) + metric = self._publishers.get(pub_id) + if metric is not None: + metric.record_backpressure_wait(ts_ns, wait_ns, msg_seq) + + def publisher_sample_inflight(self, pub_id: UUID, ts_ns: int, inflight: int) -> None: + metric = self._publishers.get(pub_id) + if metric is not None: + metric.sample_inflight(ts_ns, inflight) + + def subscriber_receive( + self, + sub_id: UUID, + ts_ns: int, + lease_ns: int, + channel_kind: ProfileChannelType, + msg_seq: int | None = None, + ) -> None: + self._expire_trace_control_if_needed(ts_ns) + metric = self._subscribers.get(sub_id) + if metric is not None: + metric.record_receive(ts_ns, lease_ns, channel_kind, msg_seq) + + def subscriber_user_span( + self, + sub_id: UUID, + ts_ns: int, + span_ns: int, + label: str | None, + msg_seq: int | None = None, + ) -> None: + self._expire_trace_control_if_needed(ts_ns) + metric = self._subscribers.get(sub_id) + if metric is not None: + metric.record_user_span(ts_ns, span_ns, label, msg_seq) + + def subscriber_attributed_backpressure( + 
self, + sub_id: UUID, + ts_ns: int, + duration_ns: int, + channel_kind: ProfileChannelType, + msg_seq: int | None = None, + ) -> None: + self._expire_trace_control_if_needed(ts_ns) + metric = self._subscribers.get(sub_id) + if metric is not None: + metric.record_attributed_backpressure( + ts_ns, duration_ns, channel_kind, msg_seq + ) + + def snapshot(self) -> ProcessProfilingSnapshot: + return ProcessProfilingSnapshot( + process_id=self._process_id, + pid=self._pid, + host=self._host, + window_seconds=WINDOW_SECONDS, + timestamp=float(PROFILE_TIME()), + publishers={ + metric.endpoint_id: metric.snapshot() + for metric in self._publishers.values() + }, + subscribers={ + metric.endpoint_id: metric.snapshot() + for metric in self._subscribers.values() + }, + ) + + def set_trace_control(self, control: ProfilingTraceControl) -> None: + # Changing filters/mode should start from a clean trace buffer so new + # consumers do not receive stale samples from an old control scope. + self._clear_trace_samples() + self._default_trace_control = control + if control.enabled and control.ttl_seconds is not None: + self._trace_control_expires_ns = PROFILE_TIME() + max( + 0, int(control.ttl_seconds * 1e9) + ) + else: + self._trace_control_expires_ns = None + + for metric in self._publishers.values(): + self._apply_trace_control_to_publisher(metric) + + for metric in self._subscribers.values(): + self._apply_trace_control_to_subscriber(metric) + + def trace_batch(self, max_samples: int = 1000) -> ProcessProfilingTraceBatch: + self._expire_trace_control_if_needed() + samples: list[ProfilingTraceSample] = [] + limit = max(1, int(max_samples)) + + queues: list[deque[ProfilingTraceSample]] = [] + for metric in self._publishers.values(): + if metric.trace_samples: + queues.append(metric.trace_samples) + for metric in self._subscribers.values(): + if metric.trace_samples: + queues.append(metric.trace_samples) + + if len(queues) == 1: + queue = queues[0] + while queue and len(samples) < limit: 
+ samples.append(queue.popleft()) + elif len(queues) > 1: + heap: list[tuple[float, int, int]] = [] + for idx, queue in enumerate(queues): + sample = queue[0] + # Include sample_seq to keep deterministic ordering when timestamps tie. + seq = sample.sample_seq if sample.sample_seq is not None else -1 + heapq.heappush(heap, (sample.timestamp, seq, idx)) + + while heap and len(samples) < limit: + _timestamp, _seq, queue_idx = heapq.heappop(heap) + queue = queues[queue_idx] + if not queue: + continue + sample = queue.popleft() + samples.append(sample) + if queue: + nxt = queue[0] + nxt_seq = nxt.sample_seq if nxt.sample_seq is not None else -1 + heapq.heappush(heap, (nxt.timestamp, nxt_seq, queue_idx)) + + return ProcessProfilingTraceBatch( + process_id=self._process_id, + pid=self._pid, + host=self._host, + timestamp=float(PROFILE_TIME()), + samples=samples, + ) + + def trace_enabled(self) -> bool: + self._expire_trace_control_if_needed() + return self._default_trace_control.enabled + + def _expire_trace_control_if_needed(self, now_ns: int | None = None) -> None: + expires_ns = self._trace_control_expires_ns + if expires_ns is None: + return + ts_ns = now_ns if now_ns is not None else PROFILE_TIME() + if ts_ns < expires_ns: + return + self.set_trace_control(ProfilingTraceControl(enabled=False)) + + def _apply_trace_control_to_publisher(self, metric: _PublisherMetrics) -> None: + control = self._default_trace_control + sample_mod = max(1, control.sample_mod) + pub_topics = set(control.publisher_topics or []) + pub_endpoint_ids = set(control.publisher_endpoint_ids or []) + trace_metrics = ( + set(control.metrics) if control.metrics is not None else None + ) + enabled = control.enabled + if enabled and pub_topics and metric.topic not in pub_topics: + enabled = False + if enabled and pub_endpoint_ids and metric.endpoint_id not in pub_endpoint_ids: + enabled = False + metric.trace_enabled = enabled + metric.trace_sample_mod = sample_mod + metric.trace_metrics = 
trace_metrics + + def _apply_trace_control_to_subscriber(self, metric: _SubscriberMetrics) -> None: + control = self._default_trace_control + sample_mod = max(1, control.sample_mod) + sub_topics = set(control.subscriber_topics or []) + sub_endpoint_ids = set(control.subscriber_endpoint_ids or []) + trace_metrics = ( + set(control.metrics) if control.metrics is not None else None + ) + enabled = control.enabled + if enabled and sub_topics and metric.topic not in sub_topics: + enabled = False + if enabled and sub_endpoint_ids and metric.endpoint_id not in sub_endpoint_ids: + enabled = False + metric.trace_enabled = enabled + metric.trace_sample_mod = sample_mod + metric.trace_metrics = trace_metrics + + def _clear_trace_samples(self) -> None: + for metric in self._publishers.values(): + metric.trace_samples.clear() + for metric in self._subscribers.values(): + metric.trace_samples.clear() + + +PROFILES = ProfileRegistry() diff --git a/src/ezmsg/core/pubclient.py b/src/ezmsg/core/pubclient.py index f9c4295..cf5f2ce 100644 --- a/src/ezmsg/core/pubclient.py +++ b/src/ezmsg/core/pubclient.py @@ -13,6 +13,7 @@ from .channelmanager import CHANNELS from .messagechannel import Channel from .messagemarshal import MessageMarshal, UninitializedMemory +from .profiling import PROFILES, PROFILE_TIME from .netprotocol import ( Address, @@ -230,6 +231,7 @@ def __init__( self._force_tcp = force_tcp self._last_backpressure_event = -1 self._graph_address = graph_address + PROFILES.register_publisher(self.id, self.topic, self._num_buffers) @property def log_name(self) -> str: @@ -243,6 +245,7 @@ def close(self) -> None: and all subscriber handling tasks. 
""" self._graph_task.cancel() + PROFILES.unregister_publisher(self.id) self._shm.close() self._connection_task.cancel() for task in self._channel_tasks.values(): @@ -369,12 +372,18 @@ async def _handle_channel( elif msg == Command.RX_ACK.value: msg_id = await read_int(reader) self._backpressure.free(info.id, msg_id % self._num_buffers) + PROFILES.publisher_sample_inflight( + self.id, PROFILE_TIME(), self._backpressure.pressure + ) except (ConnectionResetError, BrokenPipeError): logger.debug(f"Publisher {self.id}: Channel {info.id} connection fail") finally: self._backpressure.free(info.id) + PROFILES.publisher_sample_inflight( + self.id, PROFILE_TIME(), self._backpressure.pressure + ) await close_stream_writer(self._channels[info.id].writer) del self._channels[info.id] @@ -434,7 +443,15 @@ async def broadcast(self, obj: Any) -> None: if BACKPRESSURE_WARNING and (delta > BACKPRESSURE_REFRACTORY): logger.warning(f"{self.topic} under subscriber backpressure!") self._last_backpressure_event = time.time() + wait_start_ns = PROFILE_TIME() await self._backpressure.wait(buf_idx) + wait_end_ns = PROFILE_TIME() + PROFILES.publisher_backpressure_wait( + self.id, + wait_end_ns, + wait_end_ns - wait_start_ns, + msg_seq=self._msg_id, + ) # Get local channel and put variable there for local tx self._local_channel.put_local(self._msg_id, obj) @@ -502,10 +519,20 @@ async def broadcast(self, obj: Any) -> None: channel.writer.write(msg) await channel.writer.drain() self._backpressure.lease(channel.id, buf_idx) + PROFILES.publisher_sample_inflight( + self.id, PROFILE_TIME(), self._backpressure.pressure + ) except (ConnectionResetError, BrokenPipeError): logger.debug( f"Publisher {self.id}: Channel {channel.id} connection fail" ) + now_ns = PROFILE_TIME() + PROFILES.publisher_publish( + self.id, + now_ns, + self._backpressure.pressure, + msg_seq=self._msg_id, + ) self._msg_id += 1 diff --git a/src/ezmsg/core/settingsmeta.py b/src/ezmsg/core/settingsmeta.py new file mode 100644 index 
def _sanitize(value: Any) -> Any:
    """Convert ``value`` into a structure made only of plain builtins.

    Rules, applied in order:

    * enum members are replaced by their sanitized ``.value``;
    * ``None``, ``bool``, ``int``, ``float`` and ``str`` pass through;
    * mappings become ``dict`` with ``str`` keys and sanitized values;
    * lists and tuples become lists, element order preserved;
    * sets and frozensets become sorted lists;
    * dataclass instances are converted with ``asdict`` and sanitized;
    * anything else (including dataclass *types*) becomes its ``repr``.
    """
    # Enums are checked before the primitive test: IntEnum/StrEnum members
    # are int/str instances and would otherwise leak through as enum objects.
    if isinstance(value, enum.Enum):
        return _sanitize(value.value)
    if value is None or isinstance(value, (bool, int, float, str)):
        return value
    if isinstance(value, Mapping):
        return {str(key): _sanitize(val) for key, val in value.items()}
    if isinstance(value, (list, tuple)):
        return [_sanitize(val) for val in value]
    if isinstance(value, (set, frozenset)):
        # Set iteration order follows hashes, which for strings differ from
        # one process to the next; sort so the output is reproducible.
        members = [_sanitize(val) for val in value]
        try:
            return sorted(members)
        except TypeError:
            # Mixed or unorderable members: fall back to ordering by repr.
            return sorted(members, key=repr)
    # is_dataclass() is also true for dataclass types, which asdict() rejects.
    if is_dataclass(value) and not isinstance(value, type):
        try:
            return _sanitize(asdict(value))
        except Exception:
            # Best effort: asdict() deep-copies field values and may fail on
            # arbitrary user objects.
            return repr(value)
    return repr(value)
def _extract_bounds(obj: object) -> tuple[float | None, float | None] | None:
    """Return ``(lower, upper)`` numeric bounds declared on a field object.

    The lower bound is the first numeric value among the attributes ``ge``,
    ``gt`` and ``min_length``; the upper bound is the first numeric value
    among ``le``, ``lt`` and ``max_length``. Attributes on ``obj`` itself
    take precedence. If a bound is not found there and ``obj.metadata`` is a
    list or tuple, its entries are searched in order with the same attribute
    names. This covers field objects that store constraints as separate
    metadata objects instead of attributes.

    Returns ``None`` when neither bound is found; otherwise a tuple in which
    a missing side is ``None``.
    """

    def first_numeric(source: object, names: tuple[str, ...]) -> float | None:
        # First attribute in `names` holding an int/float, as a float.
        for name in names:
            candidate = getattr(source, name, None)
            if isinstance(candidate, (int, float)):
                return float(candidate)
        return None

    sources: list[object] = [obj]
    extra = getattr(obj, "metadata", None)
    if isinstance(extra, (list, tuple)):
        sources.extend(extra)

    lower: float | None = None
    upper: float | None = None
    for source in sources:
        if lower is None:
            lower = first_numeric(source, ("ge", "gt", "min_length"))
        if upper is None:
            upper = first_numeric(source, ("le", "lt", "max_length"))
        if lower is not None and upper is not None:
            break

    if lower is None and upper is None:
        return None
    return (lower, upper)
MISSING + default_val: Any | None = None + if not required: + if f.default is not MISSING: + default_val = _sanitize(f.default) + elif f.default_factory is not MISSING: + try: + default_val = _sanitize(f.default_factory()) + except Exception: + default_val = "" + metadata = f.metadata if isinstance(f.metadata, Mapping) else {} + description = metadata.get("description") + choices = metadata.get("choices") + if isinstance(choices, (list, tuple, set)): + choices = [_sanitize(val) for val in choices] + else: + choices = _choices_from_annotation(f.type) + bounds = None + ge = metadata.get("ge", metadata.get("min")) + le = metadata.get("le", metadata.get("max")) + if isinstance(ge, (int, float)) or isinstance(le, (int, float)): + bounds = ( + float(ge) if isinstance(ge, (int, float)) else None, + float(le) if isinstance(le, (int, float)) else None, + ) + field_type = _type_name(f.type) + fields.append( + SettingsFieldMetadata( + name=f.name, + field_type=field_type, + required=required, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices if isinstance(choices, list) else None, + widget_hint=_widget_hint( + field_type=field_type, + choices=choices if isinstance(choices, list) else None, + bounds=bounds, + ), + ) + ) + return SettingsSchemaMetadata( + provider="dataclass", + settings_type=_type_name(settings_type), + fields=fields, + ) + + if hasattr(settings_type, "model_fields"): + model_fields = getattr(settings_type, "model_fields") + if isinstance(model_fields, dict): + fields: list[SettingsFieldMetadata] = [] + for name, field_info in model_fields.items(): + annotation = getattr(field_info, "annotation", Any) + is_required_attr = getattr(field_info, "is_required", None) + required = ( + bool(is_required_attr()) + if callable(is_required_attr) + else bool(is_required_attr) + ) + default_val = None + if not required: + default = getattr(field_info, "default", None) + default_val = 
_sanitize(default) + description = getattr(field_info, "description", None) + choices = _choices_from_annotation(annotation) + bounds = _extract_bounds(field_info) + field_type = _type_name(annotation) + fields.append( + SettingsFieldMetadata( + name=name, + field_type=field_type, + required=required, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices, + widget_hint=_widget_hint( + field_type=field_type, choices=choices, bounds=bounds + ), + ) + ) + return SettingsSchemaMetadata( + provider="pydantic", + settings_type=_type_name(settings_type), + fields=fields, + ) + + if hasattr(settings_type, "__fields__"): + model_fields = getattr(settings_type, "__fields__") + if isinstance(model_fields, dict): + fields: list[SettingsFieldMetadata] = [] + for name, field_info in model_fields.items(): + annotation = getattr(field_info, "outer_type_", Any) + required = bool(getattr(field_info, "required", False)) + default_val = None if required else _sanitize(getattr(field_info, "default", None)) + fi = getattr(field_info, "field_info", None) + description = getattr(fi, "description", None) if fi is not None else None + choices = _choices_from_annotation(annotation) + bounds = _extract_bounds(fi if fi is not None else field_info) + field_type = _type_name(annotation) + fields.append( + SettingsFieldMetadata( + name=name, + field_type=field_type, + required=required, + default=default_val, + description=description if isinstance(description, str) else None, + bounds=bounds, + choices=choices, + widget_hint=_widget_hint( + field_type=field_type, choices=choices, bounds=bounds + ), + ) + ) + return SettingsSchemaMetadata( + provider="pydantic", + settings_type=_type_name(settings_type), + fields=fields, + ) + + param_ns = getattr(settings_type, "param", None) + if param_ns is not None and hasattr(param_ns, "objects"): + try: + objects = param_ns.objects("existing") + except Exception: + try: + objects = 
def settings_schema_from_value(value: object) -> SettingsSchemaMetadata | None:
    """Describe the settings schema of ``value``'s type.

    Returns ``None`` for ``None``; otherwise delegates to
    ``settings_schema_from_type`` with ``type(value)``.
    """
    return None if value is None else settings_schema_from_type(type(value))
.profiling import PROFILES, PROFILE_TIME +from .graphmeta import ProfileChannelType from .netprotocol import ( AddressType, @@ -128,6 +130,7 @@ def __init__( self._graph_address = graph_address self._channels = dict() + self._active_msg_seq: int | None = None if self.leaky: self._incoming = LeakyQueue( 1 if max_queue is None else max_queue, self._handle_dropped_notification @@ -135,6 +138,7 @@ def __init__( else: self._incoming = asyncio.Queue() self._initialized = asyncio.Event() + PROFILES.register_subscriber(self.id, self.topic) def _handle_dropped_notification( self, notification: typing.Tuple[UUID, int] @@ -160,6 +164,7 @@ def close(self) -> None: and closes all shared memory contexts. """ self._graph_task.cancel() + PROFILES.unregister_subscriber(self.id) async def wait_closed(self) -> None: """ @@ -295,5 +300,29 @@ async def recv_zero_copy(self) -> typing.AsyncGenerator[typing.Any, None]: break # Stale notification from an unregistered publisher — skip. - with self._channels[pub_id].get(msg_id, self.id) as msg: - yield msg + channel = self._channels[pub_id] + channel_kind = getattr(channel, "channel_kind", ProfileChannelType.UNKNOWN) + self._active_msg_seq = msg_id + try: + start_ns = PROFILE_TIME() + with channel.get(msg_id, self.id) as msg: + yield msg + end_ns = PROFILE_TIME() + PROFILES.subscriber_receive( + self.id, end_ns, end_ns - start_ns, channel_kind, msg_seq=msg_id + ) + finally: + self._active_msg_seq = None + + def begin_profile(self) -> int: + return PROFILE_TIME() + + def end_profile(self, start_ns: int, label: str | None = None) -> None: + end_ns = PROFILE_TIME() + PROFILES.subscriber_user_span( + self.id, + end_ns, + end_ns - start_ns, + label, + msg_seq=self._active_msg_seq, + ) diff --git a/tests/shutdown_runner.py b/tests/shutdown_runner.py index c185a23..39bb512 100644 --- a/tests/shutdown_runner.py +++ b/tests/shutdown_runner.py @@ -8,11 +8,14 @@ import ezmsg.core as ez +STARTED = threading.Event() + class BlockingDiskIO(ez.Unit): 
@pytest.mark.asyncio
async def test_process_registration_visible_in_snapshot():
    """A registered process and its owned units show up in the graph snapshot."""
    server = GraphService().create_server()
    server_address = server.address

    observer = GraphContext(server_address, auto_start=False)
    await observer.__aenter__()

    client = ProcessControlClient(server_address)
    await client.connect()
    assert client.client_id is not None
    expected_process_id = client.client_id

    try:
        # Register two units, then swap U1 for U3.
        await client.register(["SYS/U1", "SYS/U2"])
        await client.update_ownership(
            added_units=["SYS/U3"],
            removed_units=["SYS/U1"],
        )

        graph_snapshot = await observer.snapshot()
        assert len(graph_snapshot.processes) == 1

        entry = next(iter(graph_snapshot.processes.values()))
        assert entry.process_id == expected_process_id
        assert entry.pid is not None
        assert entry.host is not None
        assert entry.units == ["SYS/U2", "SYS/U3"]
    finally:
        await client.close()
        await asyncio.sleep(0.05)
        await observer.__aexit__(None, None, None)
        server.stop()
@pytest.mark.asyncio
async def test_process_register_succeeds_after_error_ack():
    """A rejected PROCESS_REGISTER payload does not poison the connection."""
    server = GraphService().create_server()
    server_address = server.address

    reader, writer = await GraphService(server_address).open_connection()

    async def send_register(payload: bytes) -> bytes:
        # Frame: command byte, uint64 payload length, payload; then read the ack.
        writer.write(Command.PROCESS_REGISTER.value)
        writer.write(uint64_to_bytes(len(payload)))
        writer.write(payload)
        await writer.drain()
        return await reader.read(1)

    try:
        # Handshake as a process client.
        writer.write(Command.PROCESS.value)
        await writer.drain()
        _client_id = await read_str(reader)
        handshake_ack = await reader.read(1)
        assert handshake_ack == Command.COMPLETE.value

        # Bytes that are not a pickle must be answered with an error ack.
        rejected_ack = await send_register(b"not-a-pickle-payload")
        assert rejected_ack == Command.ERROR.value

        # A well-formed registration on the same connection still succeeds.
        registration = ProcessRegistration(
            pid=123,
            host="test-host",
            units=["SYS/U1"],
        )
        accepted_ack = await send_register(pickle.dumps(registration))
        assert accepted_ack == Command.COMPLETE.value
    finally:
        await close_stream_writer(writer)
        server.stop()
@pytest.mark.asyncio
async def test_process_update_ownership_rejects_unit_ownership_collision():
    """Adding a unit that another process already owns raises RuntimeError."""
    server = GraphService().create_server()
    server_address = server.address

    owner = ProcessControlClient(server_address)
    challenger = ProcessControlClient(server_address)
    await owner.connect()
    await challenger.connect()

    try:
        await owner.register(["SYS/U1"])
        await challenger.register(["SYS/U2"])
        # SYS/U1 belongs to `owner`, so the ownership update must be refused.
        with pytest.raises(RuntimeError, match="PROCESS_UPDATE_OWNERSHIP"):
            await challenger.update_ownership(added_units=["SYS/U1"])
    finally:
        await owner.close()
        await challenger.close()
        server.stop()
b"hello" + assert response.process_id == process_key + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_routing_builtin_ping_and_stats(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U1", "SYS/U2"]) + await process.update_ownership(removed_units=["SYS/U2"], added_units=["SYS/U3"]) + + try: + ping = await observer.process_ping("SYS/U1", timeout=1.0) + assert ping.process_id == process_key + assert ping.pid > 0 + assert ping.host + assert ping.timestamp > 0.0 + + stats = await observer.process_stats("SYS/U1", timeout=1.0) + assert stats.process_id == process_key + assert stats.pid > 0 + assert stats.host + assert stats.owned_units == ["SYS/U1", "SYS/U3"] + assert stats.timestamp > 0.0 + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_routing_missing_owner_returns_error(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + try: + response = await observer.process_request( + "SYS/UNKNOWN", + "PING", + payload=b"", + timeout=0.25, + ) + assert not response.ok + assert response.error is not None + assert "No process owns unit" in response.error + assert response.error_code == ProcessControlErrorCode.UNROUTABLE_UNIT + finally: + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_routing_timeout_returns_error(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer 
= GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U2"]) + + block = asyncio.Event() + + async def blocking_handler(_request: ProcessControlRequest) -> ProcessControlResponse: + await block.wait() + return ProcessControlResponse(request_id="", ok=False) + + process.set_request_handler(blocking_handler) + + try: + response = await observer.process_request( + "SYS/U2", + "SLOW", + timeout=0.05, + ) + assert not response.ok + assert response.error is not None + assert "Timed out waiting for process response" in response.error + assert response.error_code == ProcessControlErrorCode.TIMEOUT + assert response.process_id == process_key + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() diff --git a/tests/test_profiling_api.py b/tests/test_profiling_api.py new file mode 100644 index 0000000..2b07dd0 --- /dev/null +++ b/tests/test_profiling_api.py @@ -0,0 +1,539 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ProcessControlErrorCode, + ProfilingStreamControl, + ProfilingTraceControl, +) +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +@pytest.mark.asyncio +async def test_process_profiling_snapshot_collects_pub_sub_metrics(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U1"]) + + pub = await ctx.publisher("TOPIC_PROF") + sub = await ctx.subscriber("TOPIC_PROF") + + try: + for idx in range(8): + await 
pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + snap = await ctx.process_profiling_snapshot("SYS/U1", timeout=1.0) + assert snap.process_id == process_key + assert snap.window_seconds > 0 + assert len(snap.publishers) >= 1 + assert len(snap.subscribers) >= 1 + + pub_metrics = next( + pub for pub in snap.publishers.values() if pub.topic == "TOPIC_PROF" + ) + assert pub_metrics.messages_published_total >= 8 + assert pub_metrics.publish_rate_hz_window >= 0.0 + + sub_metrics = next( + sub for sub in snap.subscribers.values() if sub.topic == "TOPIC_PROF" + ) + assert sub_metrics.messages_received_total >= 8 + assert sub_metrics.lease_time_ns_total > 0 + assert sub_metrics.lease_time_ns_avg_window >= 0.0 + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_connect_does_not_clear_preexisting_profile_metrics(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + pub = await ctx.publisher("TOPIC_PRECONNECT") + sub = await ctx.subscriber("TOPIC_PRECONNECT") + for idx in range(6): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U_PRE"]) + + try: + snap = await ctx.process_profiling_snapshot("SYS/U_PRE", timeout=1.0) + assert len(snap.publishers) >= 1 + assert len(snap.subscribers) >= 1 + pub_metrics = next( + pub for pub in snap.publishers.values() if pub.topic == "TOPIC_PRECONNECT" + ) + sub_metrics = next( + sub for sub in snap.subscribers.values() if sub.topic == "TOPIC_PRECONNECT" + ) + assert pub_metrics.messages_published_total >= 6 + assert sub_metrics.messages_received_total >= 6 + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + 
+@pytest.mark.asyncio +async def test_process_profiling_trace_control_and_batch(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U2"]) + + pub = await ctx.publisher("TOPIC_TRACE") + sub = await ctx.subscriber("TOPIC_TRACE") + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U2", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE"], + subscriber_topics=["TOPIC_TRACE"], + ), + timeout=1.0, + ) + assert response.ok + + for idx in range(5): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + span_start_ns = sub.begin_profile() + try: + await asyncio.sleep(0) + finally: + sub.end_profile(span_start_ns, "taskA") + + batch = await ctx.process_profiling_trace_batch( + "SYS/U2", max_samples=200, timeout=1.0 + ) + assert batch.process_id == process_key + assert len(batch.samples) > 0 + + disable_response = await ctx.process_set_profiling_trace( + "SYS/U2", + ProfilingTraceControl(enabled=False), + timeout=1.0, + ) + assert disable_response.ok + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_profiling_snapshot_all_and_unroutable_error_code(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U3"]) + + try: + snapshots = await ctx.profiling_snapshot_all(timeout_per_process=0.5) + assert process_key in snapshots + assert snapshots[process_key].process_id == 
process_key + + response = await ctx.process_request( + "SYS/MISSING", + "GET_PROFILING_SNAPSHOT", + timeout=0.2, + ) + assert not response.ok + assert response.error_code == ProcessControlErrorCode.UNROUTABLE_UNIT + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_subscription_push(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U4"]) + + pub = await ctx.publisher("TOPIC_STREAM") + sub = await ctx.subscriber("TOPIC_STREAM") + stream = None + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U4", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_STREAM"], + subscriber_topics=["TOPIC_STREAM"], + ), + timeout=1.0, + ) + assert response.ok + + stream = ctx.subscribe_profiling_trace( + ProfilingStreamControl(interval=0.02, max_samples=256) + ) + + for idx in range(8): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + span_start_ns = sub.begin_profile() + try: + await asyncio.sleep(0) + finally: + sub.end_profile(span_start_ns, "taskA") + + batch = await asyncio.wait_for(anext(stream), timeout=1.0) + assert batch.timestamp > 0.0 + assert process_key in batch.batches + process_batch = batch.batches[process_key] + assert process_batch.process_id == process_key + assert len(process_batch.samples) > 0 + finally: + if stream is not None: + await stream.aclose() + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_control_endpoint_metric_and_ttl(): + graph_server = GraphService().create_server() + address = 
graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U5"]) + + pub_a = await ctx.publisher("TOPIC_A") + sub_a = await ctx.subscriber("TOPIC_A") + pub_b = await ctx.publisher("TOPIC_B") + sub_b = await ctx.subscriber("TOPIC_B") + + try: + # Warm up and discover endpoint IDs for precise filter targeting. + for idx in range(3): + await pub_a.broadcast(idx) + async with sub_a.recv_zero_copy() as _msg: + await asyncio.sleep(0) + await pub_b.broadcast(idx) + async with sub_b.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + snapshot = await ctx.process_profiling_snapshot("SYS/U5", timeout=1.0) + pub_a_endpoint = next( + pub.endpoint_id + for pub in snapshot.publishers.values() + if pub.topic == "TOPIC_A" + ) + + response = await ctx.process_set_profiling_trace( + "SYS/U5", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_endpoint_ids=[pub_a_endpoint], + metrics=["publish_delta_ns"], + ), + timeout=1.0, + ) + assert response.ok + + for idx in range(8): + await pub_a.broadcast(idx) + async with sub_a.recv_zero_copy() as _msg: + await asyncio.sleep(0) + await pub_b.broadcast(idx) + async with sub_b.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + batch = await ctx.process_profiling_trace_batch( + "SYS/U5", max_samples=512, timeout=1.0 + ) + assert len(batch.samples) > 0 + assert all(sample.metric == "publish_delta_ns" for sample in batch.samples) + assert all(sample.endpoint_id == pub_a_endpoint for sample in batch.samples) + + ttl_response = await ctx.process_set_profiling_trace( + "SYS/U5", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_endpoint_ids=[pub_a_endpoint], + metrics=["publish_delta_ns"], + ttl_seconds=0.01, + ), + timeout=1.0, + ) + assert ttl_response.ok + await asyncio.sleep(0.03) + + for idx in range(3): + await pub_a.broadcast(idx) + async with 
sub_a.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + expired_batch = await ctx.process_profiling_trace_batch( + "SYS/U5", max_samples=512, timeout=1.0 + ) + assert len(expired_batch.samples) == 0 + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_subscription_stream_control(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process_a = ProcessControlClient(address) + await process_a.connect() + assert process_a.client_id is not None + process_a_key = process_a.client_id + await process_a.register(["SYS/U6"]) + + stream = None + try: + stream = ctx.subscribe_profiling_trace( + ProfilingStreamControl( + interval=0.02, + max_samples=64, + process_ids=[process_a_key], + include_empty_batches=True, + ) + ) + batch = await asyncio.wait_for(anext(stream), timeout=1.0) + assert process_a_key in batch.batches + assert len(batch.batches) == 1 + finally: + if stream is not None: + await stream.aclose() + await process_a.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_subscription_does_not_starve_peer_subscribers(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U7"]) + + pub = await ctx.publisher("TOPIC_STREAM_MULTI") + sub = await ctx.subscriber("TOPIC_STREAM_MULTI") + stream_a = None + stream_b = None + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U7", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_STREAM_MULTI"], + 
subscriber_topics=["TOPIC_STREAM_MULTI"], + ), + timeout=1.0, + ) + assert response.ok + + control = ProfilingStreamControl(interval=0.02, max_samples=256) + stream_a = ctx.subscribe_profiling_trace(control) + stream_b = ctx.subscribe_profiling_trace(control) + + for idx in range(12): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + span_start_ns = sub.begin_profile() + try: + await asyncio.sleep(0) + finally: + sub.end_profile(span_start_ns, "taskA") + + batch_a = await asyncio.wait_for(anext(stream_a), timeout=1.0) + batch_b = await asyncio.wait_for(anext(stream_b), timeout=1.0) + + assert process_key in batch_a.batches + assert process_key in batch_b.batches + assert len(batch_a.batches[process_key].samples) > 0 + assert len(batch_b.batches[process_key].samples) > 0 + finally: + if stream_a is not None: + await stream_a.aclose() + if stream_b is not None: + await stream_b.aclose() + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_batch_interleaves_publisher_and_subscriber_samples(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U8"]) + + pub = await ctx.publisher("TOPIC_TRACE_MIX") + sub = await ctx.subscriber("TOPIC_TRACE_MIX") + + try: + response = await ctx.process_set_profiling_trace( + "SYS/U8", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE_MIX"], + subscriber_topics=["TOPIC_TRACE_MIX"], + metrics=["publish_delta_ns", "lease_time_ns"], + ), + timeout=1.0, + ) + assert response.ok + + for idx in range(64): + await pub.broadcast(idx) + async with sub.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + batch = await ctx.process_profiling_trace_batch( + "SYS/U8", max_samples=32, 
timeout=1.0 + ) + metrics = {sample.metric for sample in batch.samples} + assert "publish_delta_ns" in metrics + assert "lease_time_ns" in metrics + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_profiling_trace_control_change_clears_stale_trace_samples(): + graph_server = GraphService().create_server() + address = graph_server.address + + ctx = GraphContext(address, auto_start=False) + await ctx.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/U9"]) + + pub_old = await ctx.publisher("TOPIC_TRACE_OLD") + sub_old = await ctx.subscriber("TOPIC_TRACE_OLD") + pub_new = await ctx.publisher("TOPIC_TRACE_NEW") + sub_new = await ctx.subscriber("TOPIC_TRACE_NEW") + + try: + old_response = await ctx.process_set_profiling_trace( + "SYS/U9", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE_OLD"], + metrics=["publish_delta_ns"], + ), + timeout=1.0, + ) + assert old_response.ok + + for idx in range(12): + await pub_old.broadcast(idx) + async with sub_old.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + new_response = await ctx.process_set_profiling_trace( + "SYS/U9", + ProfilingTraceControl( + enabled=True, + sample_mod=1, + publisher_topics=["TOPIC_TRACE_NEW"], + metrics=["publish_delta_ns"], + ), + timeout=1.0, + ) + assert new_response.ok + + for idx in range(8): + await pub_new.broadcast(idx) + async with sub_new.recv_zero_copy() as _msg: + await asyncio.sleep(0) + + batch = await ctx.process_profiling_trace_batch( + "SYS/U9", max_samples=256, timeout=1.0 + ) + assert len(batch.samples) > 0 + assert all(sample.topic == "TOPIC_TRACE_NEW" for sample in batch.samples) + finally: + await process.close() + await ctx.__aexit__(None, None, None) + graph_server.stop() diff --git a/tests/test_settings_api.py b/tests/test_settings_api.py new file mode 100644 index 
0000000..effbe08 --- /dev/null +++ b/tests/test_settings_api.py @@ -0,0 +1,468 @@ +import asyncio +import pickle +import time +from dataclasses import dataclass + +import pytest + +import ezmsg.core as ez +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import ( + ComponentMetadata, + DynamicSettingsMetadata, + GraphMetadata, + ProcessControlErrorCode, + ProcessControlResponse, + SettingsFieldUpdateRequest, + SettingsEventType, + SettingsSnapshotValue, +) +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +def _metadata_with_component(component_address: str) -> GraphMetadata: + return GraphMetadata( + schema_version=1, + root_name="SYS", + components={ + component_address: ComponentMetadata( + address=component_address, + name="UNIT", + component_type="example.Unit", + settings_type="example.Settings", + initial_settings=(None, {"alpha": 1}), + dynamic_settings=DynamicSettingsMetadata( + enabled=True, + input_topic=f"{component_address}/INPUT_SETTINGS", + settings_type="example.Settings", + ), + settings_schema=None, + ) + }, + ) + + +@pytest.mark.asyncio +async def test_settings_snapshot_and_events_from_metadata_registration(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + try: + component_address = "SYS/UNIT_A" + await owner.register_metadata(_metadata_with_component(component_address)) + + settings = await observer.settings_snapshot() + assert component_address in settings + assert settings[component_address].repr_value == {"alpha": 1} + assert settings[component_address].structured_value == {"alpha": 1} + assert settings[component_address].settings_schema is None + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if 
event.component_address == component_address + and event.event_type == SettingsEventType.INITIAL_SETTINGS + ] + assert matching + assert matching[-1].source_session_id == str(owner._session_id) + finally: + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@dataclass +class _SettingsMsg: + gain: int + + +class _SettingsSource(ez.Unit): + OUTPUT = ez.OutputStream(_SettingsMsg) + + @ez.publisher(OUTPUT) + async def emit(self): + yield self.OUTPUT, _SettingsMsg(gain=7) + raise ez.Complete + + +class _SettingsSink(ez.Unit): + INPUT_SETTINGS = ez.InputStream(_SettingsMsg) + + @ez.subscriber(INPUT_SETTINGS) + async def on_settings(self, msg: _SettingsMsg) -> None: + raise ez.NormalTermination + + +class _SettingsSystem(ez.Collection): + SRC = _SettingsSource() + SINK = _SettingsSink() + + def network(self) -> ez.NetworkDefinition: + return ((self.SRC.OUTPUT, self.SINK.INPUT_SETTINGS),) + + +class _SettingsOnlySystem(ez.Collection): + SINK = _SettingsSink() + + def network(self) -> ez.NetworkDefinition: + return () + + +def test_input_settings_hook_reports_to_graphserver(): + graph_server = GraphService().create_server() + address = graph_server.address + try: + ez.run(components={"SYS": _SettingsSystem()}, graph_address=address, force_single_process=True) + + async def observe() -> None: + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + try: + settings = await observer.settings_snapshot() + sink_address = "SYS/SINK" + # Process-owned settings are cleaned up when the process exits. 
+ assert sink_address not in settings + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == sink_address + and event.event_type == SettingsEventType.SETTINGS_UPDATED + and event.value.repr_value == {"gain": 7} + ] + assert matching + latest = matching[-1].value + assert latest.structured_value == {"gain": 7} + assert latest.settings_schema is not None + schema = latest.settings_schema + assert schema.provider == "dataclass" + assert any( + field.name == "gain" and "int" in field.field_type.lower() + for field in schema.fields + ) + finally: + await observer.__aexit__(None, None, None) + + asyncio.run(observe()) + finally: + graph_server.stop() + + +@pytest.mark.asyncio +async def test_graphcontext_update_settings_via_input_settings_topic(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + run_task = asyncio.create_task( + asyncio.to_thread( + ez.run, + components={"SYS": _SettingsOnlySystem()}, + graph_address=address, + force_single_process=True, + ) + ) + + try: + for _ in range(40): + try: + await observer.settings_input_topic("SYS/SINK") + break + except RuntimeError: + await asyncio.sleep(0.05) + else: + raise AssertionError("Timed out waiting for dynamic settings metadata") + + await observer.update_settings("SYS/SINK", _SettingsMsg(gain=11)) + await asyncio.wait_for(run_task, timeout=5.0) + + settings = await observer.settings_snapshot() + assert "SYS/SINK" not in settings + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == "SYS/SINK" + and event.event_type == SettingsEventType.SETTINGS_UPDATED + and event.value.repr_value == {"gain": 11} + ] + assert matching + + finally: + if not run_task.done(): + await asyncio.wait_for(run_task, timeout=5.0) + await observer.__aexit__(None, 
None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_graphcontext_update_setting_field_routes_to_process(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/SINK"]) + + try: + async def handler(request) -> ProcessControlResponse: + assert request.operation == "UPDATE_SETTING_FIELD" + assert request.payload is not None + update = pickle.loads(request.payload) + assert isinstance(update, SettingsFieldUpdateRequest) + assert update.field_path == "gain" + assert update.value == 11 + return ProcessControlResponse( + request_id=request.request_id, + ok=True, + payload=pickle.dumps( + SettingsSnapshotValue(serialized=None, repr_value={"gain": 11}) + ), + ) + + process.set_request_handler(handler) + + patched = await observer.update_setting("SYS/SINK", "gain", 11) + assert patched.repr_value == {"gain": 11} + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_graphcontext_update_setting_waits_and_propagates_process_failure(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/SINK"]) + + try: + async def handler(request) -> ProcessControlResponse: + assert request.operation == "UPDATE_SETTING_FIELD" + await asyncio.sleep(0.05) + return ProcessControlResponse( + request_id=request.request_id, + ok=False, + error="Simulated publish failure", + error_code=ProcessControlErrorCode.HANDLER_ERROR, + ) + + process.set_request_handler(handler) + + start = time.perf_counter() + 
with pytest.raises(RuntimeError, match="Simulated publish failure"): + await observer.update_setting("SYS/SINK", "gain", 99, timeout=1.0) + elapsed = time.perf_counter() - start + assert elapsed >= 0.04 + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_reported_settings_update_visible_in_snapshot_and_events(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + assert process.client_id is not None + process_key = process.client_id + + try: + await process.register(["SYS/UNIT_B"]) + await process.report_settings_update( + component_address="SYS/UNIT_B", + value=SettingsSnapshotValue(serialized=None, repr_value={"gain": 2}), + ) + + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_B"].repr_value == {"gain": 2} + + events = await observer.settings_events(after_seq=0) + matching = [ + event + for event in events + if event.component_address == "SYS/UNIT_B" + and event.event_type == SettingsEventType.SETTINGS_UPDATED + ] + assert matching + assert matching[-1].source_process_id == process_key + + stream = observer.subscribe_settings_events(after_seq=0) + streamed = await asyncio.wait_for(anext(stream), timeout=1.0) + assert streamed.component_address == "SYS/UNIT_B" + await stream.aclose() + finally: + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_session_owned_settings_removed_when_session_drops(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + try: + component_address = 
"SYS/UNIT_C" + await owner.register_metadata(_metadata_with_component(component_address)) + settings = await observer.settings_snapshot() + assert component_address in settings + + await owner._close_session() + await asyncio.sleep(0.05) + + settings = await observer.settings_snapshot() + assert component_address not in settings + finally: + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_metadata_registration_rejects_component_address_collision(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner_a = GraphContext(address, auto_start=False) + owner_b = GraphContext(address, auto_start=False) + + await owner_a.__aenter__() + await owner_b.__aenter__() + + try: + component_address = "SYS/UNIT_COLLIDE" + metadata = _metadata_with_component(component_address) + await owner_a.register_metadata(metadata) + with pytest.raises( + RuntimeError, + match="component address collision\\(s\\): SYS/UNIT_COLLIDE", + ): + await owner_b.register_metadata(metadata) + finally: + await owner_a.__aexit__(None, None, None) + await owner_b.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_owned_settings_removed_when_process_disconnects_without_session_owner(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_ORPHAN"]) + + try: + await process.report_settings_update( + component_address="SYS/UNIT_ORPHAN", + value=SettingsSnapshotValue(serialized=None, repr_value={"gain": 5}), + ) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_ORPHAN"].repr_value == {"gain": 5} + finally: + await process.close() + + await asyncio.sleep(0.05) + settings = await 
observer.settings_snapshot() + assert "SYS/UNIT_ORPHAN" not in settings + + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_disconnect_restores_metadata_initial_settings(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + await owner.__aenter__() + await observer.__aenter__() + await owner.register_metadata(_metadata_with_component("SYS/UNIT_RESTORE")) + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_RESTORE"]) + + try: + await process.report_settings_update( + component_address="SYS/UNIT_RESTORE", + value=SettingsSnapshotValue(serialized=None, repr_value={"alpha": 9}), + ) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_RESTORE"].repr_value == {"alpha": 9} + finally: + await process.close() + + await asyncio.sleep(0.05) + settings = await observer.settings_snapshot() + assert settings["SYS/UNIT_RESTORE"].repr_value == {"alpha": 1} + + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_process_settings_update_rejected_for_unowned_component(): + graph_server = GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + await process.connect() + await process.register(["SYS/UNIT_OWNED"]) + + try: + with pytest.raises(RuntimeError, match="Process control command failed"): + await process.report_settings_update( + component_address="SYS/UNIT_UNOWNED", + value=SettingsSnapshotValue(serialized=None, repr_value={"gain": 7}), + ) + settings = await observer.settings_snapshot() + assert "SYS/UNIT_UNOWNED" not in settings + finally: + await process.close() + 
await observer.__aexit__(None, None, None) + graph_server.stop() diff --git a/tests/test_topology_api.py b/tests/test_topology_api.py new file mode 100644 index 0000000..4633c09 --- /dev/null +++ b/tests/test_topology_api.py @@ -0,0 +1,105 @@ +import asyncio + +import pytest + +from ezmsg.core.graphcontext import GraphContext +from ezmsg.core.graphmeta import TopologyChangedEvent, TopologyEventType +from ezmsg.core.graphserver import GraphService +from ezmsg.core.processclient import ProcessControlClient + + +async def _next_matching_event( + stream, predicate, timeout: float = 1.0 +) -> TopologyChangedEvent: + async def _wait() -> TopologyChangedEvent: + while True: + event = await anext(stream) + if predicate(event): + return event + + return await asyncio.wait_for(_wait(), timeout=timeout) + + +@pytest.mark.asyncio +async def test_topology_subscription_reports_session_edge_changes(): + graph_server = GraphService().create_server() + address = graph_server.address + + owner = GraphContext(address, auto_start=False) + observer = GraphContext(address, auto_start=False) + + await owner.__aenter__() + await observer.__aenter__() + + stream = observer.subscribe_topology_events(after_seq=0) + try: + await owner.connect("SRC", "DST") + event = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.GRAPH_CHANGED + and "DST" in e.changed_topics + ), + timeout=1.0, + ) + assert event.source_session_id == str(owner._session_id) + + await owner.disconnect("SRC", "DST") + event = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.GRAPH_CHANGED + and "DST" in e.changed_topics + ), + timeout=1.0, + ) + assert event.source_session_id == str(owner._session_id) + finally: + await stream.aclose() + await owner.__aexit__(None, None, None) + await observer.__aexit__(None, None, None) + graph_server.stop() + + +@pytest.mark.asyncio +async def test_topology_subscription_reports_process_changes(): + graph_server = 
GraphService().create_server() + address = graph_server.address + + observer = GraphContext(address, auto_start=False) + await observer.__aenter__() + + process = ProcessControlClient(address) + stream = observer.subscribe_topology_events(after_seq=0) + + try: + await process.connect() + assert process.client_id is not None + process_key = process.client_id + await process.register(["SYS/U1"]) + + registered = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.PROCESS_CHANGED + and e.source_process_id == process_key + ), + timeout=1.0, + ) + assert registered.source_session_id is None + + await process.update_ownership(added_units=["SYS/U2"], removed_units=["SYS/U1"]) + updated = await _next_matching_event( + stream, + lambda e: ( + e.event_type == TopologyEventType.PROCESS_CHANGED + and e.source_process_id == process_key + ), + timeout=1.0, + ) + assert updated.source_session_id is None + finally: + await stream.aclose() + await process.close() + await observer.__aexit__(None, None, None) + graph_server.stop()