diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e95566849..6cf431458 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -296,6 +296,7 @@ jobs: -v ${{ github.workspace }}:/workspace -w /workspace \ python:${{ matrix.python-version }}-slim \ bash -c " + apt-get update -qq && apt-get install -y -qq --no-install-recommends curl ca-certificates && \ curl -sSf https://just.systems/install.sh | bash -s -- --to /usr/local/bin && \ just ci-setup-debian && \ just ci-test diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 593edbf4d..a6b1365a8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,13 +23,13 @@ repos: - id: mixed-line-ending - id: trailing-whitespace - # Spell check - - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 - hooks: - - id: codespell - additional_dependencies: [tomli] - args: [--skip, "*.lock,target/*"] + # # Spell check + # - repo: https://github.com/codespell-project/codespell + # rev: v2.3.0 + # hooks: + # - id: codespell + # additional_dependencies: [tomli] + # args: [--skip, "*.lock,target/*"] # Rust formatting (local hook) - repo: local diff --git a/Justfile b/Justfile index af545338a..ec630b686 100644 --- a/Justfile +++ b/Justfile @@ -56,6 +56,10 @@ test-python: test-chaos: pytest tests/python/test_chaos.py +# Run Queue & Topic chaos tests (concurrent, join/leave, mixed workload) +test-queue-topic-chaos: + pytest tests/python/test_queue_topic_chaos.py -v -s + # Format all code (Rust + Python) fmt: cargo fmt diff --git a/benchmarks/baseline_throughput.py b/benchmarks/baseline_throughput.py new file mode 100644 index 000000000..28c4bfca3 --- /dev/null +++ b/benchmarks/baseline_throughput.py @@ -0,0 +1,329 @@ +#!/usr/bin/env python3 +""" +Queue & Topic 基线吞吐 Benchmark(单节点) + +在单进程内测量 Queue 与 Topic 的基线吞吐与延迟,便于回归对比。 + +Usage: + python benchmarks/baseline_throughput.py + python benchmarks/baseline_throughput.py --duration 15 --output results.json + python benchmarks/baseline_throughput.py --queue-only + python benchmarks/baseline_throughput.py --topic-only --topic-subscribers 3 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import shutil +import tempfile +import time + +import pulsing as pul +from pulsing.queue import read_queue, write_queue +from pulsing.topic import PublishMode, read_topic, write_topic + + +def _percentile(sorted_data: list[float], p: float) -> float: + if not sorted_data: + return 0.0 + idx = min(int(len(sorted_data) * p / 100), len(sorted_data) - 1) + return sorted_data[idx] + + +# ============================================================================= +# Queue 基线 +# ============================================================================= + + +async def run_queue_baseline( + system, + storage_path: str, + duration: float, + num_buckets: int, + record_size: int, +) -> dict: + """单 writer + 单 reader,固定时长,统计写/读吞吐与延迟.""" + topic = "baseline_queue" + write_latencies_ms: list[float] = [] + read_latencies_ms: list[float] = [] + records_written = 0 + records_read = 0 + + writer = await write_queue( + system, + topic=topic, + bucket_column="id", + num_buckets=num_buckets, + storage_path=storage_path, + ) + reader = await read_queue( + system, + topic=topic, + num_buckets=num_buckets, + storage_path=storage_path, + ) + + end_time = time.monotonic() + duration + + async def produce(): + nonlocal records_written + i = 0 + while time.monotonic() < end_time: + t0 = time.perf_counter() + try: + rec = {"id": f"r{i}", 
"payload": "x" * record_size} + await writer.put(rec) + write_latencies_ms.append((time.perf_counter() - t0) * 1000) + records_written += 1 + i += 1 + except Exception: + pass + + async def consume(): + nonlocal records_read + while time.monotonic() < end_time: + t0 = time.perf_counter() + try: + batch = await reader.get(limit=50, wait=True, timeout=1.0) + if batch: + read_latencies_ms.append((time.perf_counter() - t0) * 1000) + records_read += len(batch) + except asyncio.TimeoutError: + pass + except Exception: + pass + + await asyncio.gather(produce(), consume()) + + write_latencies_ms.sort() + read_latencies_ms.sort() + + return { + "duration_s": duration, + "records_written": records_written, + "records_read": records_read, + "write_throughput_rec_s": records_written / duration if duration > 0 else 0, + "read_throughput_rec_s": records_read / duration if duration > 0 else 0, + "write_latency_ms": { + "avg": sum(write_latencies_ms) / len(write_latencies_ms) + if write_latencies_ms + else 0, + "p50": _percentile(write_latencies_ms, 50), + "p95": _percentile(write_latencies_ms, 95), + "p99": _percentile(write_latencies_ms, 99), + }, + "read_latency_ms": { + "avg": sum(read_latencies_ms) / len(read_latencies_ms) + if read_latencies_ms + else 0, + "p50": _percentile(read_latencies_ms, 50), + "p95": _percentile(read_latencies_ms, 95), + "p99": _percentile(read_latencies_ms, 99), + }, + } + + +# ============================================================================= +# Topic 基线 +# ============================================================================= + + +async def run_topic_baseline( + system, + duration: float, + num_subscribers: int, + payload_size: int, +) -> dict: + """单 publisher + N subscribers,fire_and_forget,统计发布与交付吞吐.""" + topic_name = "baseline_topic" + messages_published = 0 + delivered_per_sub: list[int] = [0] * num_subscribers + publish_latencies_ms: list[float] = [] + + writer = await write_topic(system, topic_name) + readers = [] + locks = [asyncio.Lock() for _ in range(num_subscribers)] + + for i in range(num_subscribers): + reader = await read_topic(system, topic_name, reader_id=f"sub_{i}") + + def make_cb(idx): + async def cb(msg): + async with locks[idx]: + delivered_per_sub[idx] += 1 + + return cb + + reader.add_callback(make_cb(i)) + await reader.start() + readers.append(reader) + + end_time = time.monotonic() + duration + seq = 0 + + while time.monotonic() < end_time: + t0 = time.perf_counter() + try: + await writer.publish( + {"seq": seq, "payload": "x" * payload_size}, + mode=PublishMode.FIRE_AND_FORGET, + ) + publish_latencies_ms.append((time.perf_counter() - t0) * 1000) + messages_published += 1 + seq += 1 + except Exception: + pass + + await asyncio.sleep(0.2) + + for r in readers: + await r.stop() + + publish_latencies_ms.sort() + total_delivered = sum(delivered_per_sub) + + return { + "duration_s": duration, + "num_subscribers": num_subscribers, + "messages_published": messages_published, + "total_delivered": total_delivered, + "publish_throughput_msg_s": messages_published / duration + if duration > 0 + else 0, + "delivered_throughput_msg_s": total_delivered / duration if duration > 0 else 0, + "publish_latency_ms": { + "avg": sum(publish_latencies_ms) / len(publish_latencies_ms) + if publish_latencies_ms + else 0, + "p50": _percentile(publish_latencies_ms, 50), + "p95": _percentile(publish_latencies_ms, 95), + "p99": _percentile(publish_latencies_ms, 99), + }, + } + + +# ============================================================================= 
+# Main +# ============================================================================= + + +async def main(): + parser = argparse.ArgumentParser( + description="Queue & Topic 基线吞吐 Benchmark(单节点)" + ) + parser.add_argument( + "--duration", + type=float, + default=10.0, + help="每类基准运行时长(秒)", + ) + parser.add_argument( + "--queue-only", + action="store_true", + help="仅跑 Queue 基线", + ) + parser.add_argument( + "--topic-only", + action="store_true", + help="仅跑 Topic 基线", + ) + parser.add_argument( + "--num-buckets", + type=int, + default=4, + help="Queue 桶数", + ) + parser.add_argument( + "--topic-subscribers", + type=int, + default=1, + help="Topic 订阅者数量", + ) + parser.add_argument( + "--record-size", + type=int, + default=100, + help="单条记录 payload 字节数", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="结果写入 JSON 文件路径", + ) + args = parser.parse_args() + + system = await pul.actor_system() + storage_path = tempfile.mkdtemp(prefix="baseline_queue_") + results: dict = {"queue": None, "topic": None} + + try: + if not args.topic_only: + print("Running Queue baseline...") + results["queue"] = await run_queue_baseline( + system, + storage_path=storage_path, + duration=args.duration, + num_buckets=args.num_buckets, + record_size=args.record_size, + ) + + if not args.queue_only: + print("Running Topic baseline...") + results["topic"] = await run_topic_baseline( + system, + duration=args.duration, + num_subscribers=args.topic_subscribers, + payload_size=args.record_size, + ) + finally: + await system.shutdown() + shutil.rmtree(storage_path, ignore_errors=True) + + # 打印汇总 + print() + print("=" * 60) + print("Baseline Throughput Results") + print("=" * 60) + + if results["queue"]: + q = results["queue"] + print("\n--- Queue ---") + print(f" Duration: {q['duration_s']:.1f}s") + print(f" Write throughput: {q['write_throughput_rec_s']:.0f} rec/s") + print(f" Read throughput: {q['read_throughput_rec_s']:.0f} rec/s") + print( + f" Write latency: avg={q['write_latency_ms']['avg']:.2f}ms " + f"p50={q['write_latency_ms']['p50']:.2f}ms p99={q['write_latency_ms']['p99']:.2f}ms" + ) + print( + f" Read latency: avg={q['read_latency_ms']['avg']:.2f}ms " + f"p50={q['read_latency_ms']['p50']:.2f}ms p99={q['read_latency_ms']['p99']:.2f}ms" + ) + + if results["topic"]: + t = results["topic"] + print("\n--- Topic ---") + print(f" Duration: {t['duration_s']:.1f}s") + print(f" Subscribers: {t['num_subscribers']}") + print(f" Publish throughput: {t['publish_throughput_msg_s']:.0f} msg/s") + print( + f" Delivered total: {t['total_delivered']} ({t['delivered_throughput_msg_s']:.0f} msg/s)" + ) + print( + f" Publish latency: avg={t['publish_latency_ms']['avg']:.2f}ms " + f"p50={t['publish_latency_ms']['p50']:.2f}ms p99={t['publish_latency_ms']['p99']:.2f}ms" + ) + + print("\n" + "=" * 60) + + if args.output: + with open(args.output, "w") as f: + json.dump(results, f, indent=2) + print(f"Results written to {args.output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/concurrency_sweep.py b/benchmarks/concurrency_sweep.py new file mode 100644 index 000000000..67b0533d6 --- /dev/null +++ b/benchmarks/concurrency_sweep.py @@ -0,0 +1,337 @@ +#!/usr/bin/env python3 +""" +多并发扫描:在不同 生产者/消费者 并发组合下测量 Queue 与 Topic 吞吐。 + +扫描 (P,C) 组合,例如 P,C ∈ {1,2,4,8},得到每种并发下的写/读吞吐,便于看扩展性。 + +Usage: + python benchmarks/concurrency_sweep.py + python benchmarks/concurrency_sweep.py --producers 1 2 4 8 --consumers 1 2 4 8 --duration 8 + python benchmarks/concurrency_sweep.py --queue-only --output 
sweep_queue.json +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import shutil +import tempfile +import time + +import pulsing as pul +from pulsing.queue import read_queue, write_queue +from pulsing.topic import PublishMode, read_topic, write_topic + + +# ============================================================================= +# Queue 多并发 +# ============================================================================= + + +async def run_queue_concurrent( + system, + storage_path: str, + num_producers: int, + num_consumers: int, + duration: float, + num_buckets: int, + record_size: int, +) -> dict: + """P 个 writer + C 个 reader (rank/world_size),固定时长,汇总写/读吞吐.""" + topic = f"sweep_q_p{num_producers}_c{num_consumers}" + write_counts = [0] * num_producers + read_counts = [0] * num_consumers + write_locks = [asyncio.Lock() for _ in range(num_producers)] + read_locks = [asyncio.Lock() for _ in range(num_consumers)] + + async def producer(pid: int): + writer = await write_queue( + system, + topic=topic, + bucket_column="id", + num_buckets=num_buckets, + storage_path=storage_path, + ) + end_time = time.monotonic() + duration + i = 0 + while time.monotonic() < end_time: + try: + await writer.put({"id": f"p{pid}_{i}", "payload": "x" * record_size}) + async with write_locks[pid]: + write_counts[pid] += 1 + i += 1 + except Exception: + pass + + async def consumer(rank: int): + reader = await read_queue( + system, + topic=topic, + rank=rank, + world_size=num_consumers, + num_buckets=num_buckets, + storage_path=storage_path, + ) + end_time = time.monotonic() + duration + while time.monotonic() < end_time: + try: + batch = await reader.get(limit=50, wait=True, timeout=1.0) + if batch: + async with read_locks[rank]: + read_counts[rank] += len(batch) + except asyncio.TimeoutError: + pass + except Exception: + pass + + await asyncio.gather( + *[producer(p) for p in range(num_producers)], + *[consumer(c) for c in range(num_consumers)], + ) + + total_written = sum(write_counts) + total_read = sum(read_counts) + return { + "num_producers": num_producers, + "num_consumers": num_consumers, + "records_written": total_written, + "records_read": total_read, + "write_throughput_rec_s": total_written / duration if duration > 0 else 0, + "read_throughput_rec_s": total_read / duration if duration > 0 else 0, + "duration_s": duration, + } + + +# ============================================================================= +# Topic 多并发 +# ============================================================================= + + +async def run_topic_concurrent( + system, + num_publishers: int, + num_subscribers: int, + duration: float, + payload_size: int, +) -> dict: + """P 个 publisher + C 个 subscriber,fire_and_forget,汇总发布与交付吞吐.""" + topic_name = f"sweep_t_p{num_publishers}_c{num_subscribers}" + publish_counts = [0] * num_publishers + delivered_counts = [0] * num_subscribers + pub_locks = [asyncio.Lock() for _ in range(num_publishers)] + del_locks = [asyncio.Lock() for _ in range(num_subscribers)] + + # 先起订阅者 + readers = [] + for c in range(num_subscribers): + reader = await read_topic(system, topic_name, reader_id=f"sub_{c}") + + def make_cb(cid): + async def cb(_msg): + async with del_locks[cid]: + delivered_counts[cid] += 1 + + return cb + + reader.add_callback(make_cb(c)) + await reader.start() + readers.append(reader) + + async def publisher(pid: int): + writer = await write_topic(system, topic_name, writer_id=f"pub_{pid}") + end_time = time.monotonic() + duration + seq = 0 + 
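+        # Publish in a tight loop until the deadline. FIRE_AND_FORGET does not wait
+        # for delivery, so this counts hand-offs only; deliveries are tallied
+        # separately in each subscriber's callback (one copy per subscriber).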
while time.monotonic() < end_time: + try: + await writer.publish( + {"seq": seq, "payload": "x" * payload_size}, + mode=PublishMode.FIRE_AND_FORGET, + ) + async with pub_locks[pid]: + publish_counts[pid] += 1 + seq += 1 + except Exception: + pass + + await asyncio.gather(*[publisher(p) for p in range(num_publishers)]) + + await asyncio.sleep(0.3) + for r in readers: + await r.stop() + + total_published = sum(publish_counts) + total_delivered = sum(delivered_counts) + return { + "num_publishers": num_publishers, + "num_subscribers": num_subscribers, + "messages_published": total_published, + "messages_delivered": total_delivered, + "publish_throughput_msg_s": total_published / duration if duration > 0 else 0, + "delivered_throughput_msg_s": total_delivered / duration if duration > 0 else 0, + "duration_s": duration, + } + + +# ============================================================================= +# 扫描与输出 +# ============================================================================= + + +def parse_concurrency(s: str) -> list[int]: + """解析如 '1,2,4,8' 或 '1-4' 为整数列表.""" + s = s.strip() + if "-" in s: + a, b = s.split("-", 1) + return list(range(int(a), int(b) + 1)) + return [int(x) for x in s.replace(",", " ").split()] + + +async def main(): + parser = argparse.ArgumentParser( + description="多并发扫描:不同 生产者/消费者 下的 Queue & Topic 吞吐" + ) + parser.add_argument( + "--producers", + type=str, + default="1,2,4,8", + help="生产者并发数列表,如 1,2,4,8 或 1-4", + ) + parser.add_argument( + "--consumers", + type=str, + default="1,2,4,8", + help="消费者并发数列表,如 1,2,4,8 或 1-4", + ) + parser.add_argument( + "--duration", + type=float, + default=8.0, + help="每个 (P,C) 组合运行时长(秒)", + ) + parser.add_argument( + "--num-buckets", + type=int, + default=8, + help="Queue 桶数", + ) + parser.add_argument( + "--record-size", + type=int, + default=100, + help="单条记录 payload 字节数", + ) + parser.add_argument( + "--queue-only", + action="store_true", + help="仅扫描 Queue", + ) + parser.add_argument( + "--topic-only", + action="store_true", + help="仅扫描 Topic", + ) + parser.add_argument( + "--output", + type=str, + default=None, + help="结果 JSON 文件路径", + ) + args = parser.parse_args() + + prods = parse_concurrency(args.producers) + cons = parse_concurrency(args.consumers) + + system = await pul.actor_system() + storage_path = tempfile.mkdtemp(prefix="concurrency_sweep_") + results = {"queue": [], "topic": []} + + try: + if not args.topic_only: + print("Queue concurrency sweep...") + for P in prods: + for C in cons: + if C > args.num_buckets: + continue + r = await run_queue_concurrent( + system, + storage_path=storage_path, + num_producers=P, + num_consumers=C, + duration=args.duration, + num_buckets=args.num_buckets, + record_size=args.record_size, + ) + results["queue"].append(r) + print( + f" P={P} C={C} write={r['write_throughput_rec_s']:.0f} rec/s read={r['read_throughput_rec_s']:.0f} rec/s" + ) + + if not args.queue_only: + print("Topic concurrency sweep...") + for P in prods: + for C in cons: + r = await run_topic_concurrent( + system, + num_publishers=P, + num_subscribers=C, + duration=args.duration, + payload_size=args.record_size, + ) + results["topic"].append(r) + print( + f" P={P} C={C} publish={r['publish_throughput_msg_s']:.0f} msg/s delivered={r['delivered_throughput_msg_s']:.0f} msg/s" + ) + finally: + await system.shutdown() + shutil.rmtree(storage_path, ignore_errors=True) + + # 打印汇总表 + print() + print("=" * 72) + print("Concurrency Sweep Summary") + print("=" * 72) + + if results["queue"]: + print("\n--- Queue (rec/s) 
---") + print(f"{'P':>3} {'C':>3} {'write_rec/s':>12} {'read_rec/s':>12}") + print("-" * 40) + for r in results["queue"]: + print( + f"{r['num_producers']:>3} {r['num_consumers']:>3} " + f"{r['write_throughput_rec_s']:>12.0f} {r['read_throughput_rec_s']:>12.0f}" + ) + + if results["topic"]: + print("\n--- Topic (msg/s) ---") + print(f"{'P':>3} {'C':>3} {'publish_msg/s':>14} {'delivered_msg/s':>16}") + print("-" * 50) + for r in results["topic"]: + print( + f"{r['num_publishers']:>3} {r['num_subscribers']:>3} " + f"{r['publish_throughput_msg_s']:>14.0f} {r['delivered_throughput_msg_s']:>16.0f}" + ) + + print("\n" + "=" * 72) + + if args.output: + with open(args.output, "w") as f: + json.dump( + { + "config": { + "duration_s": args.duration, + "num_buckets": args.num_buckets, + "record_size": args.record_size, + "producers": prods, + "consumers": cons, + }, + "results": results, + }, + f, + indent=2, + ) + print(f"Results written to {args.output}") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/benchmarks/run_baseline_throughput.sh b/benchmarks/run_baseline_throughput.sh new file mode 100644 index 000000000..c07be5d41 --- /dev/null +++ b/benchmarks/run_baseline_throughput.sh @@ -0,0 +1,31 @@ +#!/bin/bash +# +# 运行 Queue & Topic 基线吞吐 Benchmark(单节点) +# +# Usage: +# ./benchmarks/run_baseline_throughput.sh +# DURATION=15 ./benchmarks/run_baseline_throughput.sh +# python benchmarks/baseline_throughput.py --topic-only --topic-subscribers 3 +# + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" + +cd "$PROJECT_ROOT" + +DURATION=${DURATION:-10} +OUTPUT=${OUTPUT:-} + +echo "==========================================" +echo "Baseline Throughput (Queue + Topic)" +echo "==========================================" +echo "Duration: ${DURATION}s per benchmark" +echo "==========================================" + +if [ -n "$OUTPUT" ]; then + python benchmarks/baseline_throughput.py --duration "$DURATION" --output "$OUTPUT" +else + python benchmarks/baseline_throughput.py --duration "$DURATION" +fi diff --git a/benchmarks/run_concurrency_sweep.sh b/benchmarks/run_concurrency_sweep.sh new file mode 100644 index 000000000..e909750e9 --- /dev/null +++ b/benchmarks/run_concurrency_sweep.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# +# 多并发扫描:不同 生产者/消费者 组合下的吞吐 +# +# Usage: +# ./benchmarks/run_concurrency_sweep.sh +# DURATION=5 ./benchmarks/run_concurrency_sweep.sh +# python benchmarks/concurrency_sweep.py --producers 1 2 4 --consumers 1 2 4 --output sweep.json +# + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" + +cd "$PROJECT_ROOT" + +DURATION=${DURATION:-8} +OUTPUT=${OUTPUT:-} + +echo "==========================================" +echo "Concurrency Sweep (P producers, C consumers)" +echo "==========================================" +echo "Duration per (P,C): ${DURATION}s" +echo "Producers/Consumers: 1,2,4,8 (default)" +echo "==========================================" + +if [ -n "$OUTPUT" ]; then + python benchmarks/concurrency_sweep.py --duration "$DURATION" --output "$OUTPUT" +else + python benchmarks/concurrency_sweep.py --duration "$DURATION" +fi diff --git a/benchmarks/run_stress_multiprocessing.sh b/benchmarks/run_stress_multiprocessing.sh new file mode 100644 index 000000000..638b3c746 --- /dev/null +++ b/benchmarks/run_stress_multiprocessing.sh @@ -0,0 +1,32 @@ +#!/bin/bash +# +# 多进程压测(multiprocessing,不依赖 torchrun) +# +# Usage: +# ./benchmarks/run_stress_multiprocessing.sh 
+# NPROCS=8 DURATION=30 ./benchmarks/run_stress_multiprocessing.sh +# + +set -e + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" + +cd "$PROJECT_ROOT" + +NPROCS=${NPROCS:-4} +DURATION=${DURATION:-20} +OUTPUT=${OUTPUT:-} + +echo "==========================================" +echo "Multiprocessing Stress (Queue + Topic)" +echo "==========================================" +echo "Processes: $NPROCS" +echo "Duration: ${DURATION}s" +echo "==========================================" + +if [ -n "$OUTPUT" ]; then + python benchmarks/stress_multiprocessing.py --nprocs "$NPROCS" --duration "$DURATION" --output "$OUTPUT" +else + python benchmarks/stress_multiprocessing.py --nprocs "$NPROCS" --duration "$DURATION" +fi diff --git a/benchmarks/stress_multiprocessing.py b/benchmarks/stress_multiprocessing.py new file mode 100644 index 000000000..07e045e36 --- /dev/null +++ b/benchmarks/stress_multiprocessing.py @@ -0,0 +1,313 @@ +#!/usr/bin/env python3 +""" +多进程压测:用 multiprocessing 启动 N 个进程,每进程一个 ActorSystem 节点,组成集群, +对 Queue 与 Topic 做多进程压力测试(不依赖 torchrun)。 + +Usage: + python benchmarks/stress_multiprocessing.py --nprocs 4 --duration 20 + python benchmarks/stress_multiprocessing.py --nprocs 8 --queue-only --num-buckets 16 +""" + +from __future__ import annotations + +import argparse +import asyncio +import json +import multiprocessing as mp +import os +import shutil +import tempfile +import time +from multiprocessing import Queue + +import pulsing as pul +from pulsing.queue import read_queue, write_queue +from pulsing.topic import PublishMode, read_topic, write_topic + + +# ============================================================================= +# 单进程内异步压测逻辑(Queue + Topic) +# ============================================================================= + + +async def _run_worker( + rank: int, + world_size: int, + seed_queue: Queue, + result_queue: Queue, + config: dict, +) -> None: + """单进程 worker:加入集群后跑 Queue 与 Topic 压测,结果放入 result_queue.""" + base_port = config.get("base_port", 9100) + duration = config["duration"] + num_buckets = config["num_buckets"] + record_size = config["record_size"] + storage_path = config["storage_path"] + run_queue_bench = config.get("queue", True) + run_topic_bench = config.get("topic", True) + stabilize = config.get("stabilize_timeout", 4.0) + + # rank 0 先起节点并 bind,再往 queue 里放 seed,供其余 (world_size-1) 个进程各 get 一次 + system = None + if rank == 0: + addr = f"127.0.0.1:{base_port}" + system = await pul.actor_system(addr=addr) + for _ in range(world_size - 1): + seed_queue.put(addr) + else: + seed_addr = seed_queue.get(timeout=60.0) + await asyncio.sleep(0.5) + system = await pul.actor_system( + addr=f"127.0.0.1:{base_port + rank}", + seeds=[seed_addr], + ) + + await asyncio.sleep(stabilize) + + result = {"rank": rank, "world_size": world_size, "queue": None, "topic": None} + + if run_queue_bench: + topic_q = f"mp_queue_{world_size}" + writer = await write_queue( + system, + topic=topic_q, + bucket_column="id", + num_buckets=num_buckets, + storage_path=storage_path, + ) + reader = await read_queue( + system, + topic=topic_q, + rank=rank, + world_size=world_size, + num_buckets=num_buckets, + storage_path=storage_path, + ) + written = 0 + read_count = 0 + end_time = time.monotonic() + duration + while time.monotonic() < end_time: + try: + await writer.put( + {"id": f"r{rank}_{written}", "payload": "x" * record_size} + ) + written += 1 + except Exception: + pass + try: + await writer.flush() + except Exception: + pass + while 
time.monotonic() < end_time + 2: + try: + batch = await reader.get(limit=100, wait=True, timeout=0.5) + if batch: + read_count += len(batch) + else: + break + except asyncio.TimeoutError: + break + except Exception: + break + result["queue"] = { + "records_written": written, + "records_read": read_count, + "write_rec_s": written / duration if duration > 0 else 0, + "read_rec_s": read_count / duration if duration > 0 else 0, + } + + if run_topic_bench: + topic_t = f"mp_topic_{world_size}" + recv_count = [0] + reader = await read_topic(system, topic_t, reader_id=f"mp_{rank}") + + def on_msg(_): + recv_count[0] += 1 + + reader.add_callback(on_msg) + await reader.start() + writer = await write_topic(system, topic_t, writer_id=f"mp_{rank}") + published = 0 + end_time = time.monotonic() + duration + while time.monotonic() < end_time: + try: + await writer.publish( + {"seq": published, "payload": "x" * record_size}, + mode=PublishMode.FIRE_AND_FORGET, + ) + published += 1 + except Exception: + pass + await asyncio.sleep(0.5) + await reader.stop() + result["topic"] = { + "messages_published": published, + "messages_delivered": recv_count[0], + "publish_msg_s": published / duration if duration > 0 else 0, + "delivered_msg_s": recv_count[0] / duration if duration > 0 else 0, + } + + await system.shutdown() + result_queue.put(result) + + +def _worker_entry( + rank: int, + world_size: int, + seed_queue: Queue, + result_queue: Queue, + config: dict, +) -> None: + asyncio.run(_run_worker(rank, world_size, seed_queue, result_queue, config)) + + +# ============================================================================= +# Main:spawn 多进程并汇总 +# ============================================================================= + + +def main(): + parser = argparse.ArgumentParser( + description="多进程 Queue & Topic 压测(multiprocessing)" + ) + parser.add_argument("--nprocs", type=int, default=4, help="进程数(集群节点数)") + parser.add_argument("--duration", type=float, default=20.0, help="压测时长(秒)") + parser.add_argument("--num-buckets", type=int, default=8, help="Queue 桶数") + parser.add_argument( + "--record-size", type=int, default=100, help="单条 payload 字节数" + ) + parser.add_argument( + "--base-port", type=int, default=9100, help="首节点端口,其余 base_port+rank" + ) + parser.add_argument( + "--stabilize-timeout", type=float, default=4.0, help="集群稳定等待时间" + ) + parser.add_argument("--queue-only", action="store_true", help="仅压 Queue") + parser.add_argument("--topic-only", action="store_true", help="仅压 Topic") + parser.add_argument("--output", type=str, default=None, help="汇总结果 JSON 路径") + args = parser.parse_args() + + nprocs = args.nprocs + run_queue = not args.topic_only + run_topic = not args.queue_only + + storage_path = os.path.abspath(tempfile.mkdtemp(prefix="stress_mp_")) + config = { + "duration": args.duration, + "num_buckets": args.num_buckets, + "record_size": args.record_size, + "storage_path": storage_path, + "base_port": args.base_port, + "stabilize_timeout": args.stabilize_timeout, + "queue": run_queue, + "topic": run_topic, + } + + seed_queue = mp.Queue() + result_queue = mp.Queue() + + procs = [] + for rank in range(nprocs): + p = mp.Process( + target=_worker_entry, + args=(rank, nprocs, seed_queue, result_queue, config), + ) + procs.append(p) + + print("Starting %d processes (cluster size=%d)..." 
% (nprocs, nprocs)) + for p in procs: + p.start() + for p in procs: + p.join(timeout=args.duration + 60) + if p.is_alive(): + p.terminate() + p.join(timeout=5) + + shutil.rmtree(storage_path, ignore_errors=True) + + results = [] + while not result_queue.empty(): + try: + results.append(result_queue.get_nowait()) + except Exception: + break + + if len(results) != nprocs: + print("WARNING: got %d results, expected %d" % (len(results), nprocs)) + + # 汇总 + print() + print("=" * 60) + print("Multiprocessing Stress Summary (cluster size=%d)" % nprocs) + print("=" * 60) + + agg = {"queue": None, "topic": None} + if run_queue and results: + total_write = sum( + r["queue"]["records_written"] for r in results if r.get("queue") + ) + total_read = sum(r["queue"]["records_read"] for r in results if r.get("queue")) + dur = args.duration + agg["queue"] = { + "total_records_written": total_write, + "total_records_read": total_read, + "write_throughput_rec_s": total_write / dur if dur > 0 else 0, + "read_throughput_rec_s": total_read / dur if dur > 0 else 0, + } + print("\n--- Queue ---") + print( + " Total written: %d (%.0f rec/s)" + % (total_write, agg["queue"]["write_throughput_rec_s"]) + ) + print( + " Total read: %d (%.0f rec/s)" + % (total_read, agg["queue"]["read_throughput_rec_s"]) + ) + + if run_topic and results: + total_pub = sum( + r["topic"]["messages_published"] for r in results if r.get("topic") + ) + total_del = sum( + r["topic"]["messages_delivered"] for r in results if r.get("topic") + ) + dur = args.duration + agg["topic"] = { + "total_published": total_pub, + "total_delivered": total_del, + "publish_throughput_msg_s": total_pub / dur if dur > 0 else 0, + "delivered_throughput_msg_s": total_del / dur if dur > 0 else 0, + } + print("\n--- Topic ---") + print( + " Total published: %d (%.0f msg/s)" + % (total_pub, agg["topic"]["publish_throughput_msg_s"]) + ) + print( + " Total delivered: %d (%.0f msg/s)" + % (total_del, agg["topic"]["delivered_throughput_msg_s"]) + ) + + print("\n" + "=" * 60) + + if args.output: + with open(args.output, "w") as f: + json.dump( + { + "nprocs": nprocs, + "config": { + "duration": args.duration, + "num_buckets": args.num_buckets, + "record_size": args.record_size, + }, + "aggregate": agg, + "per_rank": results, + }, + f, + indent=2, + ) + print("Results written to %s" % args.output) + + +if __name__ == "__main__": + main() diff --git a/crates/pulsing-actor/Cargo.toml b/crates/pulsing-actor/Cargo.toml index 6950f25f5..f0fc8630f 100644 --- a/crates/pulsing-actor/Cargo.toml +++ b/crates/pulsing-actor/Cargo.toml @@ -97,9 +97,6 @@ path = "../../examples/rust/behavior_counter.rs" name = "behavior_fsm" path = "../../examples/rust/behavior_fsm.rs" -[[test]] -name = "integration" -path = "tests/integration_tests.rs" [lints] workspace = true diff --git a/crates/pulsing-actor/src/actor/address.rs b/crates/pulsing-actor/src/actor/address.rs index 30890f4da..f802b46f0 100644 --- a/crates/pulsing-actor/src/actor/address.rs +++ b/crates/pulsing-actor/src/actor/address.rs @@ -239,31 +239,39 @@ impl TryFrom for ActorPath { } } +impl From for crate::error::PulsingError { + fn from(err: AddressParseError) -> Self { + crate::error::PulsingError::from(crate::error::RuntimeError::invalid_actor_path( + err.to_string(), + )) + } +} + /// Trait for types that can be converted to ActorPath pub trait IntoActorPath { - fn into_actor_path(self) -> anyhow::Result; + fn into_actor_path(self) -> crate::error::Result; } impl IntoActorPath for &str { - fn into_actor_path(self) -> 
anyhow::Result { + fn into_actor_path(self) -> crate::error::Result { ActorPath::new(self).map_err(Into::into) } } impl IntoActorPath for String { - fn into_actor_path(self) -> anyhow::Result { + fn into_actor_path(self) -> crate::error::Result { ActorPath::new(self).map_err(Into::into) } } impl IntoActorPath for ActorPath { - fn into_actor_path(self) -> anyhow::Result { + fn into_actor_path(self) -> crate::error::Result { Ok(self) } } impl IntoActorPath for &ActorPath { - fn into_actor_path(self) -> anyhow::Result { + fn into_actor_path(self) -> crate::error::Result { Ok(self.clone()) } } diff --git a/crates/pulsing-actor/src/actor/context.rs b/crates/pulsing-actor/src/actor/context.rs index 4c22f81ce..315cacf07 100644 --- a/crates/pulsing-actor/src/actor/context.rs +++ b/crates/pulsing-actor/src/actor/context.rs @@ -24,13 +24,13 @@ pub struct ActorContext { /// Trait for system reference. #[async_trait::async_trait] pub trait ActorSystemRef: Send + Sync { - async fn actor_ref(&self, id: &ActorId) -> anyhow::Result; + async fn actor_ref(&self, id: &ActorId) -> crate::error::Result; fn node_id(&self) -> NodeId; - async fn watch(&self, watcher: &ActorId, target: &ActorId) -> anyhow::Result<()>; + async fn watch(&self, watcher: &ActorId, target: &ActorId) -> crate::error::Result<()>; - async fn unwatch(&self, watcher: &ActorId, target: &ActorId) -> anyhow::Result<()>; + async fn unwatch(&self, watcher: &ActorId, target: &ActorId) -> crate::error::Result<()>; fn local_actor_ref_by_name(&self, name: &str) -> Option; } @@ -102,7 +102,7 @@ impl ActorContext { self.cancel_token.is_cancelled() } - pub async fn actor_ref(&mut self, id: &ActorId) -> anyhow::Result { + pub async fn actor_ref(&mut self, id: &ActorId) -> crate::error::Result { if let Some(r) = self.actor_refs.get(id) { return Ok(r.clone()); } @@ -117,7 +117,7 @@ impl ActorContext { &self, msg: M, delay: Duration, - ) -> anyhow::Result<()> { + ) -> crate::error::Result<()> { let sender = self.self_sender.clone(); let message = Message::pack(&msg)?; @@ -133,12 +133,12 @@ impl ActorContext { } /// Watch another actor. - pub async fn watch(&self, target: &ActorId) -> anyhow::Result<()> { + pub async fn watch(&self, target: &ActorId) -> crate::error::Result<()> { self.system.watch(&self.actor_id, target).await } /// Stop watching another actor. - pub async fn unwatch(&self, target: &ActorId) -> anyhow::Result<()> { + pub async fn unwatch(&self, target: &ActorId) -> crate::error::Result<()> { self.system.unwatch(&self.actor_id, target).await } } diff --git a/crates/pulsing-actor/src/actor/mailbox.rs b/crates/pulsing-actor/src/actor/mailbox.rs index 1c4d78930..8724d1195 100644 --- a/crates/pulsing-actor/src/actor/mailbox.rs +++ b/crates/pulsing-actor/src/actor/mailbox.rs @@ -1,16 +1,17 @@ //! Actor mailbox - message envelope and queue. use super::traits::Message; +use crate::error::{PulsingError, Result, RuntimeError}; use tokio::sync::{mpsc, oneshot}; /// Response channel type. -pub type ResponseChannel = oneshot::Sender>; +pub type ResponseChannel = oneshot::Sender>; /// Responder - sends response back to caller (no-op for tell pattern). 
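+/// The payload is the crate's `Result`, so failures reach the caller as a `PulsingError`.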
pub struct Responder(Option); impl Responder { - pub fn send(self, result: anyhow::Result) { + pub fn send(self, result: Result) { if let Some(tx) = self.0 { let _ = tx.send(result); } @@ -111,17 +112,17 @@ impl MailboxSender { Self { inner: sender } } - pub async fn send(&self, envelope: Envelope) -> anyhow::Result<()> { + pub async fn send(&self, envelope: Envelope) -> Result<()> { self.inner .send(envelope) .await - .map_err(|_| anyhow::anyhow!("Mailbox closed")) + .map_err(|_| PulsingError::from(RuntimeError::Other("Mailbox closed".into()))) } - pub fn try_send(&self, envelope: Envelope) -> anyhow::Result<()> { - self.inner - .try_send(envelope) - .map_err(|e| anyhow::anyhow!("Mailbox send failed: {}", e)) + pub fn try_send(&self, envelope: Envelope) -> Result<()> { + self.inner.try_send(envelope).map_err(|e| { + PulsingError::from(RuntimeError::Other(format!("Mailbox send failed: {}", e))) + }) } pub fn is_closed(&self) -> bool { diff --git a/crates/pulsing-actor/src/actor/mod.rs b/crates/pulsing-actor/src/actor/mod.rs index f963206a1..70c34b264 100644 --- a/crates/pulsing-actor/src/actor/mod.rs +++ b/crates/pulsing-actor/src/actor/mod.rs @@ -1,4 +1,130 @@ //! Actor module - core abstractions. +//! +//! This module provides the fundamental building blocks for the actor system: +//! - [`Actor`] trait - implement this for your actor types +//! - [`ActorRef`] - handle for sending messages to actors +//! - [`ActorContext`] - context passed to actor handlers +//! - [`ActorPath`] and [`ActorAddress`] - actor addressing +//! - [`Message`] - message envelope type +//! +//! # Examples +//! +//! ## Creating a Simple Actor +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! use pulsing_actor::error::PulsingError; +//! use serde::{Deserialize, Serialize}; +//! +//! // Define your messages +//! #[derive(Serialize, Deserialize, Debug)] +//! struct Greet { name: String } +//! +//! #[derive(Serialize, Deserialize, Debug)] +//! struct Greeting { message: String } +//! +//! // Define your actor +//! struct Greeter; +//! +//! #[async_trait] +//! impl Actor for Greeter { +//! async fn receive( +//! &mut self, +//! msg: Message, +//! _ctx: &mut ActorContext, +//! ) -> Result { +//! if let Ok(greet) = msg.unpack::() { +//! let response = Greeting { +//! message: format!("Hello, {}!", greet.name), +//! }; +//! return Message::pack(&response); +//! } +//! Err(PulsingError::from( +//! pulsing_actor::error::RuntimeError::Other("Unknown message".into()) +//! )) +//! } +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! let greeter = system.spawn_named("services/greeter", Greeter).await?; +//! +//! let greeting: Greeting = greeter +//! .ask(Greet { name: "World".into() }) +//! .await?; +//! +//! println!("{}", greeting.message); // Hello, World! +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Using Actor Paths +//! +//! ``` +//! use pulsing_actor::actor::ActorPath; +//! +//! // Create a path +//! let path = ActorPath::new("services/api/users").unwrap(); +//! assert_eq!(path.namespace(), "services"); +//! assert_eq!(path.name(), "users"); +//! +//! // Get parent path +//! let parent = path.parent().unwrap(); +//! assert_eq!(parent.as_str(), "services/api"); +//! +//! // Create child path +//! let child = path.child("profile").unwrap(); +//! assert_eq!(child.as_str(), "services/api/users/profile"); +//! ``` +//! +//! ## Handling Different Message Types +//! +//! ```no_run +//! 
use pulsing_actor::prelude::*; +//! use pulsing_actor::error::PulsingError; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Serialize, Deserialize, Debug)] +//! enum CalculatorMsg { +//! Add(i32), +//! Subtract(i32), +//! GetResult, +//! } +//! +//! #[derive(Serialize, Deserialize, Debug)] +//! struct CountResult(i32); +//! +//! struct Calculator { value: i32 } +//! +//! #[async_trait] +//! impl Actor for Calculator { +//! async fn receive( +//! &mut self, +//! msg: Message, +//! _ctx: &mut ActorContext, +//! ) -> Result { +//! match msg.unpack::() { +//! Ok(CalculatorMsg::Add(n)) => { +//! self.value += n; +//! Message::pack(&CountResult(self.value)) +//! } +//! Ok(CalculatorMsg::Subtract(n)) => { +//! self.value -= n; +//! Message::pack(&CountResult(self.value)) +//! } +//! Ok(CalculatorMsg::GetResult) => { +//! Message::pack(&CountResult(self.value)) +//! } +//! Err(_e) => Err(PulsingError::from( +//! pulsing_actor::error::RuntimeError::Serialization( +//! "Failed to unpack message".into() +//! ) +//! )), +//! } +//! } +//! } +//! ``` mod address; mod context; diff --git a/crates/pulsing-actor/src/actor/reference.rs b/crates/pulsing-actor/src/actor/reference.rs index 5f2a32729..f11b61c27 100644 --- a/crates/pulsing-actor/src/actor/reference.rs +++ b/crates/pulsing-actor/src/actor/reference.rs @@ -3,6 +3,7 @@ use super::address::ActorPath; use super::mailbox::Envelope; use super::traits::{ActorId, Message}; +use crate::error::{PulsingError, Result, RuntimeError}; use serde::{de::DeserializeOwned, Serialize}; use std::net::SocketAddr; use std::sync::Arc; @@ -54,7 +55,7 @@ const CACHE_TTL: std::time::Duration = std::time::Duration::from_secs(5); /// Trait for resolving actor paths to ActorRefs. #[async_trait::async_trait] pub trait ActorResolver: Send + Sync { - async fn resolve_path(&self, path: &ActorPath) -> anyhow::Result; + async fn resolve_path(&self, path: &ActorPath) -> crate::error::Result; } impl LazyActorRef { @@ -67,7 +68,7 @@ impl LazyActorRef { } } - async fn get(&self) -> anyhow::Result { + async fn get(&self) -> Result { { let cache = self.cache.read().await; if let Some(ref cached) = *cache { @@ -113,29 +114,26 @@ pub trait RemoteTransport: Send + Sync { actor_id: &ActorId, msg_type: &str, payload: Vec, - ) -> anyhow::Result>; + ) -> Result>; - async fn send( - &self, - actor_id: &ActorId, - msg_type: &str, - payload: Vec, - ) -> anyhow::Result<()>; + async fn send(&self, actor_id: &ActorId, msg_type: &str, payload: Vec) -> Result<()>; /// Send a message and receive response (unified interface). 
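+    /// The default implementation only handles `Message::Single` (streaming requests return an error) and delegates to `request()`.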
- async fn send_message(&self, actor_id: &ActorId, msg: Message) -> anyhow::Result { + async fn send_message(&self, actor_id: &ActorId, msg: Message) -> Result { let Message::Single { msg_type, data } = msg else { - return Err(anyhow::anyhow!("Streaming requests not yet supported")); + return Err(PulsingError::from(RuntimeError::Other( + "Streaming requests not yet supported".into(), + ))); }; let response = self.request(actor_id, &msg_type, data).await?; Ok(Message::single("", response)) } - async fn send_oneway(&self, actor_id: &ActorId, msg: Message) -> anyhow::Result<()> { + async fn send_oneway(&self, actor_id: &ActorId, msg: Message) -> Result<()> { let Message::Single { msg_type, data } = msg else { - return Err(anyhow::anyhow!( - "Streaming not supported for fire-and-forget" - )); + return Err(PulsingError::from(RuntimeError::Other( + "Streaming not supported for fire-and-forget".into(), + ))); }; self.send(actor_id, &msg_type, data).await } @@ -201,15 +199,15 @@ impl ActorRef { /// /// Use this when you need direct access to `Message`, e.g., for streaming. /// For type-safe communication, prefer `ask()` and `tell()`. - pub async fn send(&self, msg: Message) -> anyhow::Result { + pub async fn send(&self, msg: Message) -> Result { match &self.inner { ActorRefInner::Local(sender) => { let (tx, rx) = oneshot::channel(); - sender - .send(Envelope::ask(msg, tx)) - .await - .map_err(|_| anyhow::anyhow!("Actor mailbox closed"))?; - rx.await.map_err(|_| anyhow::anyhow!("Actor dropped"))? + sender.send(Envelope::ask(msg, tx)).await.map_err(|_| { + PulsingError::from(RuntimeError::Other("Actor mailbox closed".into())) + })?; + rx.await + .map_err(|_| PulsingError::from(RuntimeError::Other("Actor dropped".into())))? } ActorRefInner::Remote(remote) => { remote.transport.send_message(&self.actor_id, msg).await @@ -224,12 +222,11 @@ impl ActorRef { } /// Send a raw message without waiting for response (low-level fire-and-forget) - pub async fn send_oneway(&self, msg: Message) -> anyhow::Result<()> { + pub async fn send_oneway(&self, msg: Message) -> Result<()> { match &self.inner { - ActorRefInner::Local(sender) => sender - .send(Envelope::tell(msg)) - .await - .map_err(|_| anyhow::anyhow!("Actor mailbox closed")), + ActorRefInner::Local(sender) => sender.send(Envelope::tell(msg)).await.map_err(|_| { + PulsingError::from(RuntimeError::Other("Actor mailbox closed".into())) + }), ActorRefInner::Remote(remote) => { remote.transport.send_oneway(&self.actor_id, msg).await } @@ -248,7 +245,7 @@ impl ActorRef { /// ```ignore /// let pong: Pong = actor.ask(Ping { value: 42 }).await?; /// ``` - pub async fn ask(&self, msg: M) -> anyhow::Result + pub async fn ask(&self, msg: M) -> Result where M: Serialize + 'static, R: DeserializeOwned, @@ -262,7 +259,7 @@ impl ActorRef { /// ```ignore /// actor.tell(Ping { value: 42 }).await?; /// ``` - pub async fn tell(&self, msg: M) -> anyhow::Result<()> + pub async fn tell(&self, msg: M) -> Result<()> where M: Serialize + 'static, { @@ -285,12 +282,18 @@ impl std::fmt::Debug for ActorRef { #[cfg(test)] mod tests { use super::*; + use crate::actor::ActorPath; #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)] struct TestMsg { value: i32, } + #[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq)] + struct TestReply { + result: i32, + } + #[tokio::test] async fn test_local_actor_ref_tell() { let (tx, mut rx) = mpsc::channel(16); @@ -300,7 +303,6 @@ mod tests { actor_ref.tell(TestMsg { value: 42 }).await.unwrap(); let envelope = 
rx.recv().await.unwrap(); - // type_name includes module path assert!(envelope.msg_type().ends_with("TestMsg")); } @@ -317,4 +319,219 @@ mod tests { assert_eq!(envelope.msg_type(), "TestMsg"); assert!(!envelope.expects_response()); } + + #[tokio::test] + async fn test_local_actor_ref_ask_success() { + let (tx, mut rx) = mpsc::channel(16); + let actor_id = ActorId::generate(); + let actor_ref = ActorRef::local(actor_id, tx.clone()); + + let reply_handle = tokio::spawn(async move { + let envelope = rx.recv().await.unwrap(); + let (msg, responder) = envelope.into_parts(); + let req: TestMsg = msg.unpack().unwrap(); + responder.send(Ok(Message::pack(&TestReply { + result: req.value * 2, + }) + .unwrap())); + }); + + let reply: TestReply = actor_ref.ask(TestMsg { value: 21 }).await.unwrap(); + assert_eq!(reply.result, 42); + reply_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_local_actor_ref_mailbox_closed() { + let (tx, rx) = mpsc::channel(16); + let actor_id = ActorId::generate(); + let actor_ref = ActorRef::local(actor_id, tx); + drop(rx); + + let err = actor_ref.tell(TestMsg { value: 1 }).await.unwrap_err(); + assert!( + err.to_string().to_lowercase().contains("mailbox") + || err.to_string().to_lowercase().contains("closed") + ); + } + + #[tokio::test] + async fn test_actor_ref_is_local_is_lazy() { + let (tx, _rx) = mpsc::channel(16); + let local_ref = ActorRef::local(ActorId::generate(), tx); + assert!(local_ref.is_local()); + assert!(!local_ref.is_lazy()); + + let path = ActorPath::new("a/b").unwrap(); + struct MockResolver; + #[async_trait::async_trait] + impl ActorResolver for MockResolver { + async fn resolve_path(&self, _path: &ActorPath) -> Result { + Err(PulsingError::from(RuntimeError::actor_not_found("mock"))) + } + } + let lazy_ref = ActorRef::lazy(path, Arc::new(MockResolver)); + assert!(!lazy_ref.is_local()); + assert!(lazy_ref.is_lazy()); + } + + #[tokio::test] + async fn test_remote_actor_ref_delegates_to_transport() { + struct MockTransport { + request_result: Result>, + send_result: Result<()>, + } + #[async_trait::async_trait] + impl RemoteTransport for MockTransport { + async fn request( + &self, + _actor_id: &ActorId, + _msg_type: &str, + _payload: Vec, + ) -> Result> { + self.request_result.clone() + } + async fn send( + &self, + _actor_id: &ActorId, + _msg_type: &str, + _payload: Vec, + ) -> Result<()> { + self.send_result.clone() + } + } + + let transport = Arc::new(MockTransport { + request_result: Ok(bincode::serialize(&TestReply { result: 100 }).unwrap()), + send_result: Ok(()), + }); + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let remote_ref = ActorRef::remote(ActorId::generate(), addr, transport); + + let reply: TestReply = remote_ref.ask(TestMsg { value: 50 }).await.unwrap(); + assert_eq!(reply.result, 100); + remote_ref.tell(TestMsg { value: 0 }).await.unwrap(); + } + + #[tokio::test] + async fn test_remote_actor_ref_transport_error() { + struct FailingTransport; + #[async_trait::async_trait] + impl RemoteTransport for FailingTransport { + async fn request(&self, _: &ActorId, _: &str, _: Vec) -> Result> { + Err(PulsingError::from(RuntimeError::connection_failed( + "127.0.0.1:1".to_string(), + "refused".to_string(), + ))) + } + async fn send(&self, _: &ActorId, _: &str, _: Vec) -> Result<()> { + Err(PulsingError::from(RuntimeError::connection_failed( + "127.0.0.1:1".to_string(), + "refused".to_string(), + ))) + } + } + let addr: SocketAddr = "127.0.0.1:1".parse().unwrap(); + let remote_ref = ActorRef::remote(ActorId::generate(), 
addr, Arc::new(FailingTransport)); + + let err = remote_ref + .ask::(TestMsg { value: 1 }) + .await + .unwrap_err(); + assert!( + err.to_string().to_lowercase().contains("connection") + || err.to_string().to_lowercase().contains("refused") + ); + } + + #[tokio::test] + async fn test_lazy_ref_resolves_and_caches() { + let (tx, mut rx) = mpsc::channel(16); + let actor_id = ActorId::generate(); + let local_ref = ActorRef::local(actor_id, tx.clone()); + + struct ResolverThatReturns { + ref_to_return: ActorRef, + } + #[async_trait::async_trait] + impl ActorResolver for ResolverThatReturns { + async fn resolve_path(&self, _path: &ActorPath) -> Result { + Ok(self.ref_to_return.clone()) + } + } + + let path = ActorPath::new("svc/echo").unwrap(); + let lazy_ref = ActorRef::lazy( + path, + Arc::new(ResolverThatReturns { + ref_to_return: local_ref, + }), + ); + + let reply_handle = tokio::spawn(async move { + for _ in 0..2 { + let Some(envelope) = rx.recv().await else { + break; + }; + let (msg, responder) = envelope.into_parts(); + let req: TestMsg = msg.unpack().unwrap(); + responder.send(Ok(Message::pack(&TestReply { + result: req.value + 1, + }) + .unwrap())); + } + }); + + let r1: TestReply = lazy_ref.ask(TestMsg { value: 10 }).await.unwrap(); + assert_eq!(r1.result, 11); + let r2: TestReply = lazy_ref.ask(TestMsg { value: 20 }).await.unwrap(); + assert_eq!(r2.result, 21); + reply_handle.await.unwrap(); + } + + #[tokio::test] + async fn test_lazy_ref_resolve_failure() { + struct FailingResolver; + #[async_trait::async_trait] + impl ActorResolver for FailingResolver { + async fn resolve_path(&self, path: &ActorPath) -> Result { + Err(PulsingError::from(RuntimeError::named_actor_not_found( + path.as_str(), + ))) + } + } + let path = ActorPath::new("svc/missing").unwrap(); + let lazy_ref = ActorRef::lazy(path, Arc::new(FailingResolver)); + + let err = lazy_ref + .ask::(TestMsg { value: 1 }) + .await + .unwrap_err(); + assert!(err.to_string().contains("missing") || err.to_string().contains("not found")); + } + + #[tokio::test] + async fn test_invalidate_cache() { + let path = ActorPath::new("a/b").unwrap(); + struct CountResolver { + count: std::sync::atomic::AtomicU32, + } + #[async_trait::async_trait] + impl ActorResolver for CountResolver { + async fn resolve_path(&self, _path: &ActorPath) -> Result { + let (tx, rx) = mpsc::channel(16); + let c = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let id = ActorId::new(c as u128); + let ref_ = ActorRef::local(id, tx); + drop(rx); + Ok(ref_) + } + } + let resolver = Arc::new(CountResolver { + count: std::sync::atomic::AtomicU32::new(0), + }); + let lazy_ref = ActorRef::lazy(path.clone(), resolver); + lazy_ref.invalidate_cache().await; + lazy_ref.invalidate_cache().await; + } } diff --git a/crates/pulsing-actor/src/actor/traits.rs b/crates/pulsing-actor/src/actor/traits.rs index 9109ad9e1..6c1594969 100644 --- a/crates/pulsing-actor/src/actor/traits.rs +++ b/crates/pulsing-actor/src/actor/traits.rs @@ -1,5 +1,6 @@ //! Core actor traits and types. 
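+//! Fallible APIs in this module return `crate::error::Result` (i.e. `PulsingError` on failure) instead of `anyhow::Result`.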
+use crate::error::{PulsingError, Result, RuntimeError}; use async_trait::async_trait; use futures::Stream; use serde::{de::DeserializeOwned, Deserialize, Serialize}; @@ -96,36 +97,36 @@ pub enum Format { impl Format { /// Parse data using this format - pub fn parse(&self, data: &[u8]) -> anyhow::Result { + pub fn parse(&self, data: &[u8]) -> Result { + let to_err = |e: &(dyn std::error::Error + '_)| { + PulsingError::from(RuntimeError::Serialization(e.to_string())) + }; match self { - Format::Bincode => Ok(bincode::deserialize(data)?), - Format::Json => Ok(serde_json::from_slice(data)?), - Format::Auto => { - // Try JSON first for Python compatibility, then bincode - match serde_json::from_slice(data) { - Ok(value) => Ok(value), - Err(_) => Ok(bincode::deserialize(data)?), - } - } + Format::Bincode => bincode::deserialize(data).map_err(|e| to_err(&e)), + Format::Json => serde_json::from_slice(data).map_err(|e| to_err(&e)), + Format::Auto => match serde_json::from_slice(data) { + Ok(value) => Ok(value), + Err(_) => bincode::deserialize(data).map_err(|e| to_err(&e)), + }, } } /// Serialize data using this format #[allow(dead_code)] - pub fn serialize(&self, value: &T) -> anyhow::Result> { + pub fn serialize(&self, value: &T) -> Result> { + let to_err = |e: &(dyn std::error::Error + '_)| { + PulsingError::from(RuntimeError::Serialization(e.to_string())) + }; match self { - Format::Bincode => Ok(bincode::serialize(value)?), - Format::Json => Ok(serde_json::to_vec(value)?), - Format::Auto => { - // Default to bincode for Auto serialization - Ok(bincode::serialize(value)?) - } + Format::Bincode => bincode::serialize(value).map_err(|e| to_err(&e)), + Format::Json => serde_json::to_vec(value).map_err(|e| to_err(&e)), + Format::Auto => bincode::serialize(value).map_err(|e| to_err(&e)), } } } /// Message stream type (stream of Single messages). -pub type MessageStream = Pin> + Send>>; +pub type MessageStream = Pin> + Send>>; /// Unified message type for both requests and responses. pub enum Message { @@ -147,31 +148,38 @@ impl Message { } } - pub fn pack(msg: &M) -> anyhow::Result { - Ok(Message::Single { - msg_type: std::any::type_name::().to_string(), - data: bincode::serialize(msg)?, - }) + pub fn pack(msg: &M) -> Result { + bincode::serialize(msg) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string()))) + .map(|data| Message::Single { + msg_type: std::any::type_name::().to_string(), + data, + }) } - pub fn unpack(self) -> anyhow::Result { + pub fn unpack(self) -> Result { match self { - Message::Single { data, .. } => Ok(bincode::deserialize(&data)?), - Message::Stream { .. } => Err(anyhow::anyhow!("Cannot unpack stream message")), + Message::Single { data, .. } => bincode::deserialize(&data) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string()))), + Message::Stream { .. } => Err(PulsingError::from(RuntimeError::Other( + "Cannot unpack stream message".into(), + ))), } } /// Parse message data with auto-detection (JSON first, then bincode) - pub fn parse(&self) -> anyhow::Result { + pub fn parse(&self) -> Result { match self { Message::Single { data, .. } => Format::Auto.parse(data), - Message::Stream { .. } => Err(anyhow::anyhow!("Cannot parse stream message")), + Message::Stream { .. 
} => Err(PulsingError::from(RuntimeError::Other( + "Cannot parse stream message".into(), + ))), } } pub fn from_channel( default_msg_type: impl Into, - rx: mpsc::Receiver>, + rx: mpsc::Receiver>, ) -> Self { let stream = tokio_stream::wrappers::ReceiverStream::new(rx); Message::Stream { @@ -182,7 +190,7 @@ impl Message { pub fn stream(default_msg_type: impl Into, stream: S) -> Self where - S: Stream> + Send + 'static, + S: Stream> + Send + 'static, { Message::Stream { default_msg_type: default_msg_type.into(), @@ -245,12 +253,12 @@ pub trait Actor: Send + Sync + 'static { } /// Called when the actor starts. - async fn on_start(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, _ctx: &mut ActorContext) -> Result<()> { Ok(()) } /// Called when the actor stops. - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> Result<()> { Ok(()) } @@ -298,12 +306,12 @@ pub trait Actor: Send + Sync + 'static { /// Message::pack(&sum) /// } /// ``` - async fn receive(&mut self, msg: Message, ctx: &mut ActorContext) -> anyhow::Result { - Err(anyhow::anyhow!( + async fn receive(&mut self, msg: Message, ctx: &mut ActorContext) -> Result { + Err(PulsingError::from(RuntimeError::Other(format!( "Actor {} does not handle message type: {}", ctx.id(), msg.msg_type() - )) + )))) } } @@ -410,7 +418,7 @@ mod tests { #[tokio::test] async fn test_message_server_streaming() { // Simulate a server streaming response with Message stream - let (tx, rx) = mpsc::channel::>(10); + let (tx, rx) = mpsc::channel::>(10); let msg = Message::from_channel("StreamResponse", rx); assert!(msg.is_stream()); @@ -445,7 +453,7 @@ mod tests { #[tokio::test] async fn test_message_client_streaming() { // Simulate a client streaming request - let (tx, rx) = mpsc::channel::>(10); + let (tx, rx) = mpsc::channel::>(10); let msg = Message::from_channel("StreamRequest", rx); tokio::spawn(async move { @@ -471,7 +479,7 @@ mod tests { #[tokio::test] async fn test_message_stream_heterogeneous() { // Test heterogeneous stream - different message types in one stream - let (tx, rx) = mpsc::channel::>(10); + let (tx, rx) = mpsc::channel::>(10); let msg = Message::from_channel("MixedStream", rx); tokio::spawn(async move { diff --git a/crates/pulsing-actor/src/behavior/core.rs b/crates/pulsing-actor/src/behavior/core.rs index d363e52e5..7e5ec5d06 100644 --- a/crates/pulsing-actor/src/behavior/core.rs +++ b/crates/pulsing-actor/src/behavior/core.rs @@ -1,6 +1,7 @@ use super::context::BehaviorContext; use super::reference::TypedRef; use crate::actor::{Actor, ActorContext, IntoActor, Message}; +use crate::error::{PulsingError, Result, RuntimeError}; use async_trait::async_trait; use futures::future::BoxFuture; use serde::{de::DeserializeOwned, Serialize}; @@ -106,15 +107,17 @@ impl Actor for BehaviorWrapper where M: Serialize + DeserializeOwned + Send + Sync + 'static, { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> Result { let typed_msg: M = msg.unpack()?; let mut behavior = self.behavior.lock().await; let mut ctx_guard = self.behavior_ctx.lock().await; - let ctx = ctx_guard - .as_mut() - .ok_or_else(|| anyhow::anyhow!("BehaviorContext not initialized"))?; + let ctx = ctx_guard.as_mut().ok_or_else(|| { + PulsingError::from(RuntimeError::Other( + "BehaviorContext not initialized".into(), + )) + })?; let action = 
behavior.receive(typed_msg, ctx).await; @@ -138,21 +141,23 @@ where _ctx.cancel_token().cancel(); - Err(anyhow::anyhow!( + Err(PulsingError::from(RuntimeError::Other(format!( "Actor stopped: {}", reason.unwrap_or_default() - )) + )))) } BehaviorAction::AlreadyStopped => { let actor_name = self.name.lock().await; let name = actor_name.as_deref().unwrap_or("unknown"); tracing::warn!(actor = %name, "Message received after actor stopped"); - Err(anyhow::anyhow!("Actor already stopped")) + Err(PulsingError::from(RuntimeError::Other( + "Actor already stopped".into(), + ))) } } } - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> Result<()> { // Get or derive the actor name let actor_name = ctx .named_path() diff --git a/crates/pulsing-actor/src/behavior/mod.rs b/crates/pulsing-actor/src/behavior/mod.rs index e23cc3b39..a865d680e 100644 --- a/crates/pulsing-actor/src/behavior/mod.rs +++ b/crates/pulsing-actor/src/behavior/mod.rs @@ -1,4 +1,181 @@ //! Behavior-based actor programming model. +//! +//! This module provides a functional, closure-based API for defining actors, +//! as an alternative to implementing the [`Actor`] trait directly. +//! +//! # Overview +//! +//! The Behavior API is ideal for: +//! - Simple stateful actors with clear state transitions +//! - Quick prototyping without defining new types +//! - Functional programming style +//! +//! # Key Types +//! +//! - [`stateful`] - Create a behavior with mutable state +//! - [`stateless`] - Create a behavior without state +//! - [`BehaviorAction`] - Control flow (continue, stop, become) +//! +//! # Examples +//! +//! ## Simple Counter (Stateful) +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! use pulsing_actor::behavior::{stateful, BehaviorAction}; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Serialize, Deserialize, Debug, Clone)] +//! enum CounterMsg { +//! Increment(i32), +//! Decrement(i32), +//! Get, +//! } +//! +//! #[derive(Serialize, Deserialize, Debug)] +//! struct Count(i32); +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! +//! // Create a stateful behavior +//! let counter = stateful(0i32, |count, msg: CounterMsg, _ctx| { +//! match msg { +//! CounterMsg::Increment(n) => { +//! *count += n; +//! BehaviorAction::Same +//! } +//! CounterMsg::Decrement(n) => { +//! *count -= n; +//! BehaviorAction::Same +//! } +//! CounterMsg::Get => { +//! println!("Current count: {}", count); +//! BehaviorAction::Same +//! } +//! } +//! }); +//! +//! let actor = system.spawn(counter).await?; +//! actor.tell(CounterMsg::Increment(10)).await?; +//! actor.tell(CounterMsg::Decrement(3)).await?; +//! actor.tell(CounterMsg::Get).await?; +//! +//! system.shutdown().await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Stateless Echo Server +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! use pulsing_actor::behavior::{stateless, BehaviorAction}; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Serialize, Deserialize, Debug, Clone)] +//! struct Echo { text: String } +//! +//! #[derive(Serialize, Deserialize, Debug)] +//! struct EchoReply { text: String } +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! +//! // Create a stateless behavior +//! let echo = stateless(|msg: Echo, _ctx| { +//! Box::pin(async move { +//! 
let reply = EchoReply { text: msg.text.clone() }; +//! // In real code, you would send the reply back +//! let _ = reply; +//! BehaviorAction::Same +//! }) +//! }); +//! +//! let actor = system.spawn(echo).await?; +//! actor.tell(Echo { text: "Hello!".into() }).await?; +//! +//! system.shutdown().await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## State Machine with Become +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! use pulsing_actor::behavior::{stateful, BehaviorAction}; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Serialize, Deserialize, Debug, Clone)] +//! enum TrafficLightMsg { +//! Next, +//! GetState, +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! # let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! // Traffic light state machine +//! let green = stateful((), |_, msg: TrafficLightMsg, ctx| { +//! match msg { +//! TrafficLightMsg::Next => { +//! println!("Green -> Yellow"); +//! // Transition to yellow state +//! BehaviorAction::Same +//! } +//! TrafficLightMsg::GetState => { +//! println!("State: Green"); +//! BehaviorAction::Same +//! } +//! } +//! }); +//! +//! let _actor = system.spawn(green).await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Stopping a Behavior +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! use pulsing_actor::behavior::{stateful, BehaviorAction}; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Serialize, Deserialize, Debug, Clone)] +//! enum WorkerMsg { +//! DoWork, +//! Shutdown, +//! } +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! # let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! let worker = stateful(0i32, |count, msg: WorkerMsg, _ctx| { +//! match msg { +//! WorkerMsg::DoWork => { +//! *count += 1; +//! if *count >= 10 { +//! println!("Worker completed 10 tasks, stopping..."); +//! BehaviorAction::stop() +//! } else { +//! BehaviorAction::Same +//! } +//! } +//! WorkerMsg::Shutdown => { +//! println!("Worker shutting down..."); +//! BehaviorAction::stop_with_reason("Shutdown requested") +//! } +//! } +//! }); +//! +//! let _actor = system.spawn(worker).await?; +//! # Ok(()) +//! # } +//! ``` mod context; mod core; diff --git a/crates/pulsing-actor/src/behavior/reference.rs b/crates/pulsing-actor/src/behavior/reference.rs index 2c38d091c..236f48b63 100644 --- a/crates/pulsing-actor/src/behavior/reference.rs +++ b/crates/pulsing-actor/src/behavior/reference.rs @@ -70,27 +70,25 @@ where &self.name } - fn resolve(&self) -> anyhow::Result { + fn resolve(&self) -> crate::error::Result { match &self.mode { ResolutionMode::Direct(inner) => Ok(inner.clone()), ResolutionMode::Dynamic(system) => { system.local_actor_ref_by_name(&self.name).ok_or_else(|| { - anyhow::Error::from(PulsingError::from(RuntimeError::actor_not_found( - self.name.clone(), - ))) + PulsingError::from(RuntimeError::actor_not_found(self.name.clone())) }) } } } /// Send a message without waiting for response. - pub async fn tell(&self, msg: M) -> anyhow::Result<()> { + pub async fn tell(&self, msg: M) -> crate::error::Result<()> { let actor_ref = self.resolve()?; actor_ref.tell(msg).await } /// Send a message and wait for a response. - pub async fn ask(&self, msg: M) -> anyhow::Result + pub async fn ask(&self, msg: M) -> crate::error::Result where R: DeserializeOwned, { @@ -99,17 +97,22 @@ where } /// Send a message and wait for a response with timeout. 
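// A minimal caller-side sketch (not from the patch) of the Result-typed TypedRef API shown above.
// Assumed for illustration: `TypedRef` is re-exported at `pulsing_actor::behavior::TypedRef`, and the
// target actor replies to `Get` with a `Count`; the message types mirror the module docs and are
// repeated here only so the sketch is self-contained.
use pulsing_actor::behavior::TypedRef;
use pulsing_actor::error::{PulsingError, Result, RuntimeError};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize, Debug, Clone)]
enum CounterMsg {
    Increment(i32),
    Get,
}

#[derive(Serialize, Deserialize, Debug)]
struct Count(i32);

async fn bump_and_read(counter: &TypedRef<CounterMsg>) -> Result<i32> {
    // Fire-and-forget: a failed name lookup surfaces as RuntimeError::ActorNotFound.
    counter.tell(CounterMsg::Increment(1)).await?;
    // Request/response: the expected reply type drives deserialization.
    let Count(value) = counter.ask(CounterMsg::Get).await?;
    Ok(value)
}

// Callers can branch on the structured error instead of matching on strings.
fn is_not_found(err: &PulsingError) -> bool {
    matches!(err, PulsingError::Runtime(RuntimeError::ActorNotFound { .. }))
}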
- pub async fn ask_timeout(&self, msg: M, timeout: Duration) -> anyhow::Result + pub async fn ask_timeout(&self, msg: M, timeout: Duration) -> crate::error::Result where R: DeserializeOwned, { tokio::time::timeout(timeout, self.ask(msg)) .await - .map_err(|_| anyhow::anyhow!("Ask timeout after {:?}", timeout))? + .map_err(|_| { + PulsingError::from(RuntimeError::Other(format!( + "Ask timeout after {:?}", + timeout + ))) + })? } /// Get the underlying untyped ActorRef. - pub fn as_untyped(&self) -> anyhow::Result { + pub fn as_untyped(&self) -> crate::error::Result { self.resolve() } diff --git a/crates/pulsing-actor/src/cluster/backends/gossip.rs b/crates/pulsing-actor/src/cluster/backends/gossip.rs index 00827cf0e..d0ff8c60b 100644 --- a/crates/pulsing-actor/src/cluster/backends/gossip.rs +++ b/crates/pulsing-actor/src/cluster/backends/gossip.rs @@ -6,6 +6,7 @@ use crate::cluster::{ member::{MemberInfo, NamedActorInfo, NamedActorInstance}, GossipCluster, GossipConfig, }; +use crate::error::Result; use crate::transport::http2::Http2Transport; use async_trait::async_trait; use std::collections::HashMap; @@ -39,11 +40,11 @@ impl GossipBackend { #[async_trait] impl NamingBackend for GossipBackend { - async fn join(&self, seeds: Vec) -> anyhow::Result<()> { + async fn join(&self, seeds: Vec) -> Result<()> { self.cluster.join(seeds).await } - async fn leave(&self) -> anyhow::Result<()> { + async fn leave(&self) -> Result<()> { self.cluster.leave().await } diff --git a/crates/pulsing-actor/src/cluster/backends/head.rs b/crates/pulsing-actor/src/cluster/backends/head.rs index 505e0c5bc..16e0e8c8a 100644 --- a/crates/pulsing-actor/src/cluster/backends/head.rs +++ b/crates/pulsing-actor/src/cluster/backends/head.rs @@ -5,6 +5,7 @@ use crate::cluster::{ member::{MemberInfo, MemberStatus, NamedActorInfo, NamedActorInstance}, NamingBackend, }; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::transport::http2::{Http2Client, Http2Config}; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -345,13 +346,11 @@ impl HeadNodeBackend { } /// Handle node registration (head node only) - pub async fn handle_register_node( - &self, - node_id: NodeId, - addr: SocketAddr, - ) -> anyhow::Result<()> { + pub async fn handle_register_node(&self, node_id: NodeId, addr: SocketAddr) -> Result<()> { if !self.is_head() { - return Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let mut state = self.state.write().await; @@ -363,9 +362,11 @@ impl HeadNodeBackend { } /// Handle heartbeat (head node only) - pub async fn handle_heartbeat(&self, node_id: &NodeId) -> anyhow::Result<()> { + pub async fn handle_heartbeat(&self, node_id: &NodeId) -> Result<()> { if !self.is_head() { - return Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let mut state = self.state.write().await; @@ -373,10 +374,14 @@ impl HeadNodeBackend { if head_state.update_heartbeat(node_id) { Ok(()) } else { - Err(anyhow::anyhow!("Node not found: {}", node_id)) + Err(PulsingError::from(RuntimeError::node_not_found( + node_id.to_string(), + ))) } } else { - Err(anyhow::anyhow!("Invalid state")) + Err(PulsingError::from(RuntimeError::Other( + "Invalid state".into(), + ))) } } @@ -387,9 +392,11 @@ impl HeadNodeBackend { node_id: NodeId, actor_id: Option, metadata: HashMap, - ) -> anyhow::Result<()> { + ) -> Result<()> { if !self.is_head() { - return 
Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let mut state = self.state.write().await; @@ -405,9 +412,11 @@ impl HeadNodeBackend { &self, path: &ActorPath, node_id: &NodeId, - ) -> anyhow::Result<()> { + ) -> Result<()> { if !self.is_head() { - return Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let mut state = self.state.write().await; @@ -419,13 +428,11 @@ impl HeadNodeBackend { } /// Handle register actor (head node only) - pub async fn handle_register_actor( - &self, - actor_id: ActorId, - node_id: NodeId, - ) -> anyhow::Result<()> { + pub async fn handle_register_actor(&self, actor_id: ActorId, node_id: NodeId) -> Result<()> { if !self.is_head() { - return Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let mut state = self.state.write().await; @@ -441,9 +448,11 @@ impl HeadNodeBackend { &self, actor_id: &ActorId, node_id: &NodeId, - ) -> anyhow::Result<()> { + ) -> Result<()> { if !self.is_head() { - return Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let mut state = self.state.write().await; @@ -458,9 +467,11 @@ impl HeadNodeBackend { } /// Handle sync request (head node only) - returns current state - pub async fn handle_sync(&self) -> anyhow::Result { + pub async fn handle_sync(&self) -> Result { if !self.is_head() { - return Err(anyhow::anyhow!("Not a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Not a head node".into(), + ))); } let state = self.state.read().await; @@ -475,7 +486,9 @@ impl HeadNodeBackend { actors, }) } else { - Err(anyhow::anyhow!("Invalid state")) + Err(PulsingError::from(RuntimeError::Other( + "Invalid state".into(), + ))) } } @@ -485,29 +498,35 @@ impl HeadNodeBackend { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result { + ) -> Result { let head_addr = self.head_addr().ok_or_else(|| { - anyhow::anyhow!("Cannot make request to head node: this is a head node") + PulsingError::from(RuntimeError::Other( + "Cannot make request to head node: this is a head node".into(), + )) + })?; + let client = self.http_client.as_ref().ok_or_else(|| { + PulsingError::from(RuntimeError::Other("HTTP client not available".into())) })?; - let client = self - .http_client - .as_ref() - .ok_or_else(|| anyhow::anyhow!("HTTP client not available"))?; - let response_bytes = client.ask(head_addr, path, msg_type, payload).await?; - let result: T = bincode::deserialize(&response_bytes)?; + let response_bytes = client + .ask(head_addr, path, msg_type, payload) + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string())))?; + let result: T = bincode::deserialize(&response_bytes) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; Ok(result) } /// Send request to head node without response (worker mode only) - async fn tell_head(&self, path: &str, msg_type: &str, payload: Vec) -> anyhow::Result<()> { - let head_addr = self - .head_addr() - .ok_or_else(|| anyhow::anyhow!("Cannot send to head node: this is a head node"))?; - let client = self - .http_client - .as_ref() - .ok_or_else(|| anyhow::anyhow!("HTTP client not available"))?; + async fn tell_head(&self, path: &str, msg_type: &str, payload: Vec) -> Result<()> { + let head_addr = self.head_addr().ok_or_else(|| { + 
PulsingError::from(RuntimeError::Other( + "Cannot send to head node: this is a head node".into(), + )) + })?; + let client = self.http_client.as_ref().ok_or_else(|| { + PulsingError::from(RuntimeError::Other("HTTP client not available".into())) + })?; tracing::debug!( head_addr = %head_addr, @@ -516,13 +535,18 @@ impl HeadNodeBackend { "Sending tell to head node" ); - client.tell(head_addr, path, msg_type, payload).await + client + .tell(head_addr, path, msg_type, payload) + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string()))) } /// Sync with head node (worker mode only) - async fn sync_from_head(&self) -> anyhow::Result<()> { + async fn sync_from_head(&self) -> Result<()> { if self.is_head() { - return Err(anyhow::anyhow!("Cannot sync: this is a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Cannot sync: this is a head node".into(), + ))); } let sync: SyncResponse = self @@ -538,16 +562,19 @@ impl HeadNodeBackend { } /// Register with head node (worker mode only) - async fn register_with_head(&self) -> anyhow::Result<()> { + async fn register_with_head(&self) -> Result<()> { if self.is_head() { - return Err(anyhow::anyhow!("Cannot register: this is a head node")); + return Err(PulsingError::from(RuntimeError::Other( + "Cannot register: this is a head node".into(), + ))); } let req = RegisterNodeRequest { node_id: self.local_node, addr: self.local_addr, }; - let payload = bincode::serialize(&req)?; + let payload = bincode::serialize(&req) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; self.tell_head("/cluster/head/register", "register_node", payload) .await?; @@ -562,17 +589,18 @@ impl HeadNodeBackend { } /// Send heartbeat to head node (worker mode only) - async fn send_heartbeat(&self) -> anyhow::Result<()> { + async fn send_heartbeat(&self) -> Result<()> { if self.is_head() { - return Err(anyhow::anyhow!( - "Cannot send heartbeat: this is a head node" - )); + return Err(PulsingError::from(RuntimeError::Other( + "Cannot send heartbeat: this is a head node".into(), + ))); } let req = HeartbeatRequest { node_id: self.local_node, }; - let payload = bincode::serialize(&req)?; + let payload = bincode::serialize(&req) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; self.tell_head("/cluster/head/heartbeat", "heartbeat", payload) .await?; @@ -581,7 +609,7 @@ impl HeadNodeBackend { } /// Process pending sync operations (worker mode only) - async fn process_pending_sync(&self) -> anyhow::Result<()> { + async fn process_pending_sync(&self) -> Result<()> { if self.is_head() { return Ok(()); // No pending sync for head node } @@ -608,7 +636,9 @@ impl HeadNodeBackend { actor_id, metadata: metadata.clone(), }; - let payload = bincode::serialize(&req)?; + let payload = bincode::serialize(&req).map_err(|e| { + PulsingError::from(RuntimeError::Serialization(e.to_string())) + })?; if let Err(e) = self .tell_head( "/cluster/head/named_actor/register", @@ -634,7 +664,9 @@ impl HeadNodeBackend { path: path.clone(), node_id: self.local_node, }; - let payload = bincode::serialize(&req)?; + let payload = bincode::serialize(&req).map_err(|e| { + PulsingError::from(RuntimeError::Serialization(e.to_string())) + })?; if let Err(e) = self .tell_head( "/cluster/head/named_actor/unregister", @@ -656,7 +688,9 @@ impl HeadNodeBackend { actor_id, node_id: self.local_node, }; - let payload = bincode::serialize(&req)?; + let payload = bincode::serialize(&req).map_err(|e| { + 
PulsingError::from(RuntimeError::Serialization(e.to_string())) + })?; if let Err(e) = self .tell_head("/cluster/head/actor/register", "register_actor", payload) .await @@ -673,7 +707,9 @@ impl HeadNodeBackend { actor_id, node_id: self.local_node, }; - let payload = bincode::serialize(&req)?; + let payload = bincode::serialize(&req).map_err(|e| { + PulsingError::from(RuntimeError::Serialization(e.to_string())) + })?; if let Err(e) = self .tell_head( "/cluster/head/actor/unregister", @@ -707,7 +743,7 @@ impl NamingBackend for HeadNodeBackend { // Node Management // ======================================================================== - async fn join(&self, _seeds: Vec) -> anyhow::Result<()> { + async fn join(&self, _seeds: Vec) -> Result<()> { if self.is_head() { // Head node: register itself let mut state = self.state.write().await; @@ -725,7 +761,7 @@ impl NamingBackend for HeadNodeBackend { } } - async fn leave(&self) -> anyhow::Result<()> { + async fn leave(&self) -> Result<()> { if self.is_head() { // Head node: clear all registrations let mut state = self.state.write().await; diff --git a/crates/pulsing-actor/src/cluster/gossip.rs b/crates/pulsing-actor/src/cluster/gossip.rs index 0f98a09ee..4a4fc5963 100644 --- a/crates/pulsing-actor/src/cluster/gossip.rs +++ b/crates/pulsing-actor/src/cluster/gossip.rs @@ -6,6 +6,7 @@ use super::member::{ }; use super::swim::SwimConfig; use crate::actor::{ActorId, ActorPath, NodeId, StopReason}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::transport::http2::Http2Transport; use rand::prelude::IndexedRandom; use serde::{Deserialize, Serialize}; @@ -343,7 +344,7 @@ impl GossipCluster { } /// Join cluster via seed nodes - pub async fn join(&self, seed_addrs: Vec) -> anyhow::Result<()> { + pub async fn join(&self, seed_addrs: Vec) -> Result<()> { if seed_addrs.is_empty() { tracing::info!("No seed nodes provided, starting as first node"); return Ok(()); @@ -356,7 +357,8 @@ impl GossipCluster { from_addr: self.state.local_addr, current_epoch: self.state.current_epoch(), }; - let payload = bincode::serialize(&msg)?; + let payload = bincode::serialize(&msg) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; tracing::info!(seeds = ?seed_addrs, "Probing seed nodes"); @@ -411,7 +413,7 @@ impl GossipCluster { } /// Leave cluster gracefully - pub async fn leave(&self) -> anyhow::Result<()> { + pub async fn leave(&self) -> Result<()> { tracing::info!(node_id = %self.state.local_node, "Leaving cluster"); Ok(()) } @@ -421,7 +423,7 @@ impl GossipCluster { &self, msg: GossipMessage, peer_addr: SocketAddr, - ) -> anyhow::Result> { + ) -> Result> { match msg { GossipMessage::Meet { from, diff --git a/crates/pulsing-actor/src/cluster/naming.rs b/crates/pulsing-actor/src/cluster/naming.rs index 0c4f30207..371b53348 100644 --- a/crates/pulsing-actor/src/cluster/naming.rs +++ b/crates/pulsing-actor/src/cluster/naming.rs @@ -2,6 +2,7 @@ use crate::actor::{ActorId, ActorPath, NodeId, StopReason}; use crate::cluster::member::{MemberInfo, NamedActorInfo, NamedActorInstance}; +use crate::error::Result; use async_trait::async_trait; use std::collections::HashMap; use std::net::SocketAddr; @@ -10,9 +11,9 @@ use tokio_util::sync::CancellationToken; /// Trait for naming backends that provide cluster membership and actor discovery. 
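// The gossip and head-node backends above repeat the same bincode-error -> RuntimeError::Serialization
// mapping at every call site. A small helper along these lines (the names `to_ser_err` and `encode` are
// hypothetical, not from the patch) would keep the exact error shape used in this diff while naming the
// conversion once.
use pulsing_actor::error::{PulsingError, Result, RuntimeError};
use serde::Serialize;

fn to_ser_err(e: impl std::fmt::Display) -> PulsingError {
    PulsingError::from(RuntimeError::Serialization(e.to_string()))
}

fn encode<T: Serialize>(value: &T) -> Result<Vec<u8>> {
    // Same behavior as the inline closures above, just written once.
    bincode::serialize(value).map_err(|e| to_ser_err(e))
}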
#[async_trait] pub trait NamingBackend: Send + Sync { - async fn join(&self, seeds: Vec) -> anyhow::Result<()>; + async fn join(&self, seeds: Vec) -> Result<()>; - async fn leave(&self) -> anyhow::Result<()>; + async fn leave(&self) -> Result<()>; async fn all_members(&self) -> Vec; diff --git a/crates/pulsing-actor/src/error.rs b/crates/pulsing-actor/src/error.rs index eb9f2bee8..27c3dc6f6 100644 --- a/crates/pulsing-actor/src/error.rs +++ b/crates/pulsing-actor/src/error.rs @@ -15,6 +15,49 @@ //! - Timeout errors (operation timeouts) //! - Unsupported errors (unsupported operations) //! → Maps to Python: PulsingActorError (and subclasses) +//! +//! # Examples +//! +//! ## Error Classification +//! +//! ``` +//! use pulsing_actor::error::{PulsingError, RuntimeError, ActorError}; +//! +//! // Create a runtime error +//! let err = PulsingError::from(RuntimeError::ActorNotFound { +//! name: "my_actor".into(), +//! }); +//! +//! assert!(err.is_runtime()); +//! assert!(!err.is_actor()); +//! +//! // Create an actor error +//! let actor_err = PulsingError::from(ActorError::Timeout { +//! operation: "ask".into(), +//! duration_ms: 30000, +//! }); +//! +//! assert!(!actor_err.is_runtime()); +//! assert!(actor_err.is_actor()); +//! ``` +//! +//! ## Converting Errors +//! +//! ``` +//! use pulsing_actor::error::{PulsingError, RuntimeError}; +//! +//! fn do_something() -> Result<(), PulsingError> { +//! // Automatic conversion from RuntimeError +//! Err(RuntimeError::ActorNotFound { +//! name: "test".into(), +//! }.into()) +//! } +//! +//! match do_something() { +//! Err(e) => println!("Error: {}", e), +//! Ok(_) => unreachable!(), +//! } +//! ``` use thiserror::Error; @@ -47,25 +90,7 @@ impl PulsingError { } } -impl From for PulsingError { - fn from(err: anyhow::Error) -> Self { - // Try to downcast to known error types - if let Some(runtime_err) = err.downcast_ref::() { - return Self::Runtime(runtime_err.clone()); - } - if let Some(actor_err) = err.downcast_ref::() { - return Self::Actor(actor_err.clone()); - } - // Try to downcast to PulsingError itself - if let Some(pulsing_err) = err.downcast_ref::() { - return pulsing_err.clone(); - } - // Default to runtime error for unknown errors - Self::Runtime(RuntimeError::Other(err.to_string())) - } -} - -// Implement Clone for PulsingError to support downcast +// Implement Clone for PulsingError impl Clone for PulsingError { fn clone(&self) -> Self { match self { @@ -276,6 +301,20 @@ impl RuntimeError { Self::RequestTimeout { timeout_ms } } + /// Create a connection closed error + pub fn connection_closed(reason: impl Into) -> Self { + Self::ConnectionClosed { + reason: reason.into(), + } + } + + /// Create an invalid response error + pub fn invalid_response(reason: impl Into) -> Self { + Self::InvalidResponse { + reason: reason.into(), + } + } + /// Create a TLS error pub fn tls_error(reason: impl Into) -> Self { Self::TlsError { @@ -373,6 +412,51 @@ impl RuntimeError { pub fn io(err: std::io::Error) -> Self { Self::Io(err.to_string()) } + + /// Get the error kind as a snake_case string (for structured serialization) + pub fn kind(&self) -> &'static str { + match self { + Self::ActorNotFound { .. } => "actor_not_found", + Self::ActorAlreadyExists { .. } => "actor_already_exists", + Self::ActorNotLocal { .. } => "actor_not_local", + Self::ActorStopped { .. } => "actor_stopped", + Self::ActorMailboxFull { .. } => "actor_mailbox_full", + Self::InvalidActorPath { .. } => "invalid_actor_path", + Self::MessageTypeMismatch { .. 
} => "message_type_mismatch", + Self::ActorSpawnFailed { .. } => "actor_spawn_failed", + Self::ConnectionFailed { .. } => "connection_failed", + Self::ConnectionClosed { .. } => "connection_closed", + Self::RequestTimeout { .. } => "request_timeout", + Self::InvalidResponse { .. } => "invalid_response", + Self::TlsError { .. } => "tls_error", + Self::ProtocolError { .. } => "protocol_error", + Self::ClusterNotInitialized => "cluster_not_initialized", + Self::NodeNotFound { .. } => "node_not_found", + Self::NamedActorNotFound { .. } => "named_actor_not_found", + Self::NoHealthyInstances { .. } => "no_healthy_instances", + Self::JoinFailed { .. } => "join_failed", + Self::GossipError { .. } => "gossip_error", + Self::InvalidConfigValue { .. } => "invalid_config_value", + Self::MissingRequiredConfig { .. } => "missing_required_config", + Self::ConflictingConfig { .. } => "conflicting_config", + Self::InvalidAddress { .. } => "invalid_address", + Self::Io(_) => "io_error", + Self::Serialization(_) => "serialization_error", + Self::Other(_) => "other", + } + } + + /// Extract actor name if this error is related to a specific actor + pub fn actor_name(&self) -> Option<&str> { + match self { + Self::ActorNotFound { name } => Some(name), + Self::ActorAlreadyExists { name } => Some(name), + Self::ActorNotLocal { name } => Some(name), + Self::ActorStopped { name } => Some(name), + Self::ActorMailboxFull { name } => Some(name), + _ => None, + } + } } impl From for RuntimeError { @@ -499,6 +583,7 @@ pub type Result = std::result::Result; #[cfg(test)] mod tests { use super::*; + use std::error::Error; #[test] fn test_runtime_error_display() { @@ -562,4 +647,78 @@ mod tests { assert_eq!(err1, err2); assert_ne!(err1, err3); } + + #[test] + fn test_runtime_error_from_io_error() { + let io_err = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found"); + let runtime_err: RuntimeError = io_err.into(); + assert!(matches!(runtime_err, RuntimeError::Io(_))); + assert!(runtime_err.to_string().contains("file not found")); + } + + #[test] + fn test_runtime_error_kind_and_actor_name() { + let err = RuntimeError::actor_not_found("my-actor"); + assert_eq!(err.kind(), "actor_not_found"); + assert_eq!(err.actor_name(), Some("my-actor")); + + let err = RuntimeError::Other("generic".to_string()); + assert_eq!(err.kind(), "other"); + assert_eq!(err.actor_name(), None); + } + + #[test] + fn test_pulsing_error_is_runtime_is_actor() { + let runtime_err = PulsingError::from(RuntimeError::actor_not_found("x")); + assert!(runtime_err.is_runtime()); + assert!(!runtime_err.is_actor()); + + let actor_err = PulsingError::from(ActorError::business(400, "y", None)); + assert!(!actor_err.is_runtime()); + assert!(actor_err.is_actor()); + } + + /// Test error propagation with ? 
operator + fn propagate_result(ok: bool) -> Result<()> { + if ok { + Ok(()) + } else { + Err(RuntimeError::actor_not_found("test").into()) + } + } + + #[test] + fn test_error_propagation() { + assert!(propagate_result(true).is_ok()); + let err = propagate_result(false).unwrap_err(); + assert!(err.is_runtime()); + assert!(err.to_string().contains("test")); + } + + #[test] + fn test_runtime_error_resolve_helpers() { + let err = RuntimeError::no_healthy_instances("svc/echo"); + assert_eq!(err.kind(), "no_healthy_instances"); + assert!(err.to_string().to_lowercase().contains("svc/echo")); + + let err = RuntimeError::node_not_found("node-42"); + assert_eq!(err.kind(), "node_not_found"); + assert!(err.to_string().contains("node-42")); + + let err = RuntimeError::named_actor_not_found("a/b"); + assert_eq!(err.kind(), "named_actor_not_found"); + assert!(err.to_string().contains("a/b")); + + let err = RuntimeError::ClusterNotInitialized; + assert_eq!(err.kind(), "cluster_not_initialized"); + assert!(err.to_string().to_lowercase().contains("cluster")); + } + + #[test] + fn test_error_source_chain() { + let io_err = std::io::Error::new(std::io::ErrorKind::ConnectionRefused, "refused"); + let runtime_err: RuntimeError = io_err.into(); + let pulsing_err: PulsingError = runtime_err.into(); + assert!(pulsing_err.source().is_some()); + } } diff --git a/crates/pulsing-actor/src/lib.rs b/crates/pulsing-actor/src/lib.rs index 186be94bc..412265bc9 100644 --- a/crates/pulsing-actor/src/lib.rs +++ b/crates/pulsing-actor/src/lib.rs @@ -6,40 +6,132 @@ //! //! ## Quick Start //! -//! ```rust,ignore +//! Create your first actor and send messages: +//! +//! ```no_run //! use pulsing_actor::prelude::*; +//! use pulsing_actor::error::PulsingError; +//! use serde::{Deserialize, Serialize}; //! -//! #[derive(Serialize, Deserialize)] +//! // Define messages +//! #[derive(Serialize, Deserialize, Debug)] //! struct Ping { value: i32 } -//! #[derive(Serialize, Deserialize)] +//! +//! #[derive(Serialize, Deserialize, Debug)] //! struct Pong { result: i32 } //! +//! // Define actor state //! struct Counter { count: i32 } //! //! #[async_trait] //! impl Actor for Counter { -//! async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -//! -> anyhow::Result -//! { -//! if msg.msg_type().ends_with("Ping") { -//! let ping: Ping = msg.unpack()?; +//! async fn receive( +//! &mut self, +//! msg: Message, +//! _ctx: &mut ActorContext, +//! ) -> Result { +//! if let Ok(ping) = msg.unpack::() { //! self.count += ping.value; //! return Message::pack(&Pong { result: self.count }); //! } -//! Err(anyhow::anyhow!("Unknown message")) +//! Err(PulsingError::from( +//! pulsing_actor::error::RuntimeError::Other("Unknown message type".into()) +//! )) //! } //! } //! //! #[tokio::main] //! async fn main() -> anyhow::Result<()> { -//! let system = ActorSystem::builder().build().await?; -//! let actor_ref = system.spawn_named("services/counter", Counter { count: 0 }).await?; +//! // Create actor system +//! let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! +//! // Spawn a named actor +//! let actor_ref = system +//! .spawn_named("services/counter", Counter { count: 0 }) +//! .await?; +//! +//! // Send message and await response //! let pong: Pong = actor_ref.ask(Ping { value: 42 }).await?; //! println!("Result: {}", pong.result); +//! +//! // Clean shutdown +//! system.shutdown().await?; +//! Ok(()) +//! } +//! ``` +//! +//! ## Using the Behavior API +//! +//! 
For simpler actors, use the Behavior API with closures: +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! use pulsing_actor::behavior::stateful; +//! use serde::{Deserialize, Serialize}; +//! +//! #[derive(Serialize, Deserialize, Debug, Clone)] +//! enum CounterMsg { +//! Increment(i32), +//! Get, +//! } +//! +//! #[derive(Serialize, Deserialize, Debug)] +//! struct Count(i32); +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! +//! // Create a stateful behavior with closure +//! let counter = stateful(0i32, |count, msg: CounterMsg, _ctx| { +//! match msg { +//! CounterMsg::Increment(n) => { +//! *count += n; +//! pulsing_actor::behavior::BehaviorAction::Same +//! } +//! CounterMsg::Get => { +//! // Return current count without changing state +//! let _ = count; +//! pulsing_actor::behavior::BehaviorAction::Same +//! } +//! } +//! }); +//! +//! let actor_ref = system.spawn(counter).await?; +//! actor_ref.tell(CounterMsg::Increment(10)).await?; +//! +//! system.shutdown().await?; +//! Ok(()) +//! } +//! ``` +//! +//! ## Cluster Mode +//! +//! Run actors across multiple nodes: +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! +//! #[tokio::main] +//! async fn main() -> anyhow::Result<()> { +//! // Node 1: Seed node +//! let addr: std::net::SocketAddr = "0.0.0.0:8000".parse()?; +//! let config = SystemConfig::with_addr(addr); +//! let system1 = ActorSystem::new(config).await?; +//! +//! // Node 2: Join the cluster +//! let addr: std::net::SocketAddr = "0.0.0.0:8001".parse()?; +//! let seed: std::net::SocketAddr = "127.0.0.1:8000".parse()?; +//! let config = SystemConfig::with_addr(addr) +//! .with_seeds(vec![seed]); +//! let system2 = ActorSystem::new(config).await?; +//! +//! // Actors can now communicate across nodes +//! println!("Cluster formed with 2 nodes"); +//! +//! Ok(()) +//! } +//! ``` pub mod actor; pub mod behavior; diff --git a/crates/pulsing-actor/src/supervision.rs b/crates/pulsing-actor/src/supervision.rs index 743910e4e..152cfdad2 100644 --- a/crates/pulsing-actor/src/supervision.rs +++ b/crates/pulsing-actor/src/supervision.rs @@ -13,7 +13,7 @@ pub enum RestartPolicy { Always, /// Restart the actor only if it failed (non-normal exit) OnFailure, - /// Never restart the actor (default) + /// Never restart the actor (default). On panic or unrecoverable error, the actor stops and is not restarted #[default] Never, } @@ -95,7 +95,7 @@ pub struct SupervisionSpec { pub policy: RestartPolicy, /// Backoff strategy pub backoff: BackoffStrategy, - /// Maximum number of restarts allowed + /// Maximum number of restarts allowed (used when policy is Always/OnFailure) pub max_restarts: u32, /// Time window for max_restarts (optional). /// If set, max_restarts applies only within this sliding window. diff --git a/crates/pulsing-actor/src/system/config.rs b/crates/pulsing-actor/src/system/config.rs index f7843f26e..b250d99b9 100644 --- a/crates/pulsing-actor/src/system/config.rs +++ b/crates/pulsing-actor/src/system/config.rs @@ -8,6 +8,7 @@ use crate::actor::{NodeId, DEFAULT_MAILBOX_SIZE}; use crate::cluster::{GossipConfig, HeadNodeConfig}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::policies::LoadBalancingPolicy; use crate::supervision::SupervisionSpec; use crate::transport::Http2Config; @@ -109,8 +110,11 @@ impl SystemConfig { /// The passphrase is used to derive a shared CA certificate, enabling /// automatic mutual TLS authentication. 
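// A short sketch (not from the patch) combining the cluster-mode example from the lib.rs docs with the
// `with_tls` helper documented just above; it assumes the "tls" feature is enabled and uses a placeholder
// passphrase.
#[cfg(feature = "tls")]
async fn start_secure_worker() -> anyhow::Result<()> {
    use pulsing_actor::prelude::*;

    let addr: std::net::SocketAddr = "0.0.0.0:8001".parse()?;
    let seed: std::net::SocketAddr = "127.0.0.1:8000".parse()?;
    // Every node started with the same passphrase derives the same CA, so the nodes can
    // authenticate each other via mutual TLS without distributing certificates by hand.
    let config = SystemConfig::with_addr(addr)
        .with_seeds(vec![seed])
        .with_tls("example-cluster-passphrase")?;
    let system = ActorSystem::new(config).await?;
    system.shutdown().await?;
    Ok(())
}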
#[cfg(feature = "tls")] - pub fn with_tls(mut self, passphrase: &str) -> anyhow::Result { - self.http2_config = self.http2_config.with_tls(passphrase)?; + pub fn with_tls(mut self, passphrase: &str) -> Result { + self.http2_config = self + .http2_config + .with_tls(passphrase) + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string())))?; Ok(self) } @@ -227,9 +231,9 @@ impl std::error::Error for ConfigValidationError {} #[derive(Default)] pub struct ActorSystemBuilder { /// Bind address (stored as Result for deferred error handling) - addr: Option>, + addr: Option>, /// Seed nodes (stored as Results for deferred error handling) - seeds: Vec>, + seeds: Vec>, /// Mailbox capacity mailbox_capacity: Option, /// Gossip configuration @@ -239,7 +243,7 @@ pub struct ActorSystemBuilder { /// Head node mode is_head_node: bool, /// Head node address (if set, makes this a worker) - head_addr: Option>, + head_addr: Option>, /// Head node configuration head_node_config: Option, } @@ -271,12 +275,13 @@ impl ActorSystemBuilder { /// Enable TLS with passphrase #[cfg(feature = "tls")] - pub fn tls(mut self, passphrase: &str) -> anyhow::Result { + pub fn tls(mut self, passphrase: &str) -> Result { let http2_config = self .http2_config .take() .unwrap_or_default() - .with_tls(passphrase)?; + .with_tls(passphrase) + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string())))?; self.http2_config = Some(http2_config); Ok(self) } @@ -313,7 +318,7 @@ impl ActorSystemBuilder { /// Build the ActorSystem /// /// Returns an error if any address parsing or validation failed. - pub async fn build(self) -> anyhow::Result> { + pub async fn build(self) -> crate::error::Result> { let addr = Self::parse_optional_addr("bind address", self.addr)?.unwrap_or(DEFAULT_BIND_ADDR); @@ -339,46 +344,47 @@ impl ActorSystemBuilder { fn parse_optional_addr( label: &str, - input: Option>, - ) -> anyhow::Result> { + input: Option>, + ) -> Result> { match input { Some(Ok(addr)) => Ok(Some(addr)), - Some(Err(invalid)) => Err(anyhow::anyhow!("Invalid {}: {}", label, invalid)), + Some(Err(invalid)) => Err(PulsingError::from(RuntimeError::Other(format!( + "Invalid {}: {}", + label, invalid + )))), None => Ok(None), } } fn parse_addr_list( label: &str, - seeds: Vec>, - ) -> anyhow::Result> { + seeds: Vec>, + ) -> Result> { let mut addrs = Vec::with_capacity(seeds.len()); for (i, seed) in seeds.into_iter().enumerate() { match seed { Ok(addr) => addrs.push(addr), Err(invalid) => { - return Err(anyhow::anyhow!( + return Err(PulsingError::from(RuntimeError::Other(format!( "Invalid {} at index {}: {}", - label, - i, - invalid - )); + label, i, invalid + )))); } } } Ok(addrs) } - fn validate_config(config: &SystemConfig) -> anyhow::Result<()> { + fn validate_config(config: &SystemConfig) -> Result<()> { let errors = config.validate(); if errors.is_empty() { return Ok(()); } let error_msgs: Vec = errors.iter().map(|e| e.to_string()).collect(); - Err(anyhow::anyhow!( + Err(PulsingError::from(RuntimeError::Other(format!( "Configuration validation failed:\n - {}", error_msgs.join("\n - ") - )) + )))) } } @@ -509,10 +515,118 @@ mod tests { let err = ConfigValidationError::ConflictingHeadNodeConfig; assert!(err.to_string().contains("head_node")); } + + // --- Config parsing --- + + #[test] + fn test_config_with_seeds() { + let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap(); + let config = SystemConfig::standalone().with_seeds(vec![addr]); + assert_eq!(config.seed_nodes.len(), 1); + assert_eq!(config.seed_nodes[0], addr); + } + + #[test] + fn 
test_config_with_head_node_and_head_addr() { + let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap(); + let config = SystemConfig::standalone().with_head_node(); + assert!(config.is_head_node); + assert!(config.head_addr.is_none()); + + let config = SystemConfig::standalone().with_head_addr(addr); + assert!(!config.is_head_node); + assert_eq!(config.head_addr, Some(addr)); + } + + #[tokio::test] + async fn test_builder_invalid_addr_parse() { + let result = ActorSystemBuilder::default() + .addr("not-a-valid-address") + .build() + .await; + let err = match result { + Ok(_) => panic!("expected build to fail"), + Err(e) => e, + }; + assert!(err.to_string().to_lowercase().contains("invalid")); + } + + #[tokio::test] + async fn test_builder_invalid_seed_parse() { + let result = ActorSystemBuilder::default() + .seeds(vec!["127.0.0.1:0", "invalid-seed"]) + .build() + .await; + let err = match result { + Ok(_) => panic!("expected build to fail"), + Err(e) => e, + }; + let msg = err.to_string(); + assert!(msg.contains("Invalid") && msg.contains("index")); + } + + #[tokio::test] + async fn test_builder_validation_mailbox_too_small() { + let result = ActorSystemBuilder::default() + .addr("127.0.0.1:0") + .mailbox_capacity(1) + .build() + .await; + let err = match result { + Ok(_) => panic!("expected validation to fail"), + Err(e) => e, + }; + assert!(err.to_string().to_lowercase().contains("mailbox")); + } + + #[tokio::test] + async fn test_builder_validation_mailbox_too_large() { + let result = ActorSystemBuilder::default() + .addr("127.0.0.1:0") + .mailbox_capacity(10_000_000) + .build() + .await; + let err = match result { + Ok(_) => panic!("expected validation to fail"), + Err(e) => e, + }; + assert!(err.to_string().to_lowercase().contains("mailbox")); + } + + #[test] + fn test_validation_multiple_errors() { + let addr: SocketAddr = "127.0.0.1:9000".parse().unwrap(); + let config = SystemConfig { + default_mailbox_capacity: 1, + is_head_node: true, + head_addr: Some(addr), + ..Default::default() + }; + let errors = config.validate(); + assert_eq!(errors.len(), 2); + let has_mailbox = errors + .iter() + .any(|e| matches!(e, ConfigValidationError::MailboxTooSmall { .. })); + let has_conflict = errors + .iter() + .any(|e| matches!(e, ConfigValidationError::ConflictingHeadNodeConfig)); + assert!(has_mailbox); + assert!(has_conflict); + } + + #[tokio::test] + async fn test_builder_valid_build() { + let result = ActorSystemBuilder::default() + .addr("127.0.0.1:0") + .build() + .await; + let system = result.unwrap(); + system.shutdown().await.unwrap(); + } } /// Helper type for flexible address input (defers parsing errors) -pub struct AddrInput(Result); +pub struct AddrInput(std::result::Result); impl From for AddrInput { fn from(addr: SocketAddr) -> Self { diff --git a/crates/pulsing-actor/src/system/handler.rs b/crates/pulsing-actor/src/system/handler.rs index 8fc4b148f..5520dd143 100644 --- a/crates/pulsing-actor/src/system/handler.rs +++ b/crates/pulsing-actor/src/system/handler.rs @@ -1,13 +1,12 @@ //! 
HTTP/2 message handler for the actor system -use super::handle::LocalActorHandle; use crate::actor::{ActorId, ActorPath, Envelope, Message, NodeId}; use crate::cluster::backends::{RegisterActorRequest, UnregisterActorRequest}; use crate::cluster::{GossipBackend, GossipMessage, HeadNodeBackend, NamingBackend}; -use crate::error::{PulsingError, RuntimeError}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::metrics::{metrics, SystemMetrics as PrometheusMetrics}; +use crate::system::registry::ActorRegistry; use crate::transport::Http2ServerHandler; -use dashmap::DashMap; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::net::SocketAddr; @@ -15,79 +14,60 @@ use std::sync::atomic::Ordering; use std::sync::Arc; use tokio::sync::{mpsc, RwLock}; -/// Unified message handler for HTTP/2 transport +/// Unified message handler for HTTP/2 transport. +/// +/// Uses [`ActorRegistry`] for actor lookup instead of raw DashMaps. pub(crate) struct SystemMessageHandler { node_id: NodeId, - /// Local actors indexed by ActorId - local_actors: Arc>, - /// Actor name to ActorId mapping - actor_names: Arc>, - named_actor_paths: Arc>, + /// Actor registry for local actor management + registry: Arc, + /// Cluster backend cluster: Arc>>>, } impl SystemMessageHandler { pub fn new( node_id: NodeId, - local_actors: Arc>, - actor_names: Arc>, - named_actor_paths: Arc>, + registry: Arc, cluster: Arc>>>, ) -> Self { Self { node_id, - local_actors, - actor_names, - named_actor_paths, + registry, cluster, } } - /// Find actor sender by name or ActorId (O(1) lookup) - fn find_actor_sender(&self, actor_name: &str) -> anyhow::Result> { - // First try by name -> ActorId -> handle - if let Some(actor_id) = self.actor_names.get(actor_name) { - if let Some(handle) = self.local_actors.get(actor_id.value()) { - return Ok(handle.sender.clone()); - } - } - - // Then try parsing as ActorId (UUID format) - if let Ok(uuid) = uuid::Uuid::parse_str(actor_name) { - let actor_id = ActorId::new(uuid.as_u128()); - if let Some(handle) = self.local_actors.get(&actor_id) { - return Ok(handle.sender.clone()); - } - } - - Err(anyhow::Error::from(PulsingError::from( - RuntimeError::actor_not_found(actor_name.to_string()), - ))) + /// Find actor sender by name or ActorId (O(1) lookup via registry) + fn find_actor_sender(&self, actor_name: &str) -> Result> { + self.registry.find_actor_sender(actor_name).ok_or_else(|| { + PulsingError::from(RuntimeError::actor_not_found(actor_name.to_string())) + }) } /// Dispatch a message to an actor (ask pattern) - async fn dispatch_message(&self, path: &str, msg: Message) -> anyhow::Result { + async fn dispatch_message(&self, path: &str, msg: Message) -> Result { if let Some(actor_name) = path.strip_prefix("/actors/") { self.send_to_local_actor(actor_name, msg).await } else if let Some(named_path) = path.strip_prefix("/named/") { self.send_to_named_actor(named_path, msg).await } else { - Err(anyhow::anyhow!("Invalid path: {}", path)) + Err(PulsingError::from(RuntimeError::invalid_actor_path(path))) } } /// Dispatch a fire-and-forget message - async fn dispatch_tell(&self, path: &str, msg: Message) -> anyhow::Result<()> { + async fn dispatch_tell(&self, path: &str, msg: Message) -> Result<()> { if let Some(actor_name) = path.strip_prefix("/actors/") { self.tell_local_actor(actor_name, msg).await } else if let Some(named_path) = path.strip_prefix("/named/") { self.tell_named_actor(named_path, msg).await } else { - Err(anyhow::anyhow!("Invalid path: {}", path)) + 
Err(PulsingError::from(RuntimeError::invalid_actor_path(path))) } } - async fn send_to_local_actor(&self, actor_name: &str, msg: Message) -> anyhow::Result { + async fn send_to_local_actor(&self, actor_name: &str, msg: Message) -> Result { let sender = self.find_actor_sender(actor_name)?; let (tx, rx) = tokio::sync::oneshot::channel(); @@ -96,39 +76,38 @@ impl SystemMessageHandler { sender .send(envelope) .await - .map_err(|_| anyhow::anyhow!("Actor mailbox closed"))?; + .map_err(|_| PulsingError::from(RuntimeError::actor_stopped(actor_name)))?; - rx.await.map_err(|_| anyhow::anyhow!("Actor dropped"))? + rx.await + .map_err(|_| PulsingError::from(RuntimeError::actor_stopped(actor_name)))? } - async fn tell_local_actor(&self, actor_name: &str, msg: Message) -> anyhow::Result<()> { + async fn tell_local_actor(&self, actor_name: &str, msg: Message) -> Result<()> { let sender = self.find_actor_sender(actor_name)?; let envelope = Envelope::tell(msg); sender .send(envelope) .await - .map_err(|_| anyhow::anyhow!("Actor mailbox closed"))?; + .map_err(|_| PulsingError::from(RuntimeError::actor_stopped(actor_name)))?; Ok(()) } - async fn send_to_named_actor(&self, path: &str, msg: Message) -> anyhow::Result { + async fn send_to_named_actor(&self, path: &str, msg: Message) -> Result { let actor_name = self - .named_actor_paths - .get(path) - .ok_or_else(|| anyhow::anyhow!("Named actor not found: {}", path))? - .clone(); + .registry + .get_actor_name_by_path(path) + .ok_or_else(|| PulsingError::from(RuntimeError::named_actor_not_found(path)))?; self.send_to_local_actor(&actor_name, msg).await } - async fn tell_named_actor(&self, path: &str, msg: Message) -> anyhow::Result<()> { + async fn tell_named_actor(&self, path: &str, msg: Message) -> Result<()> { let actor_name = self - .named_actor_paths - .get(path) - .ok_or_else(|| anyhow::anyhow!("Named actor not found: {}", path))? 
- .clone(); + .registry + .get_actor_name_by_path(path) + .ok_or_else(|| PulsingError::from(RuntimeError::named_actor_not_found(path)))?; self.tell_local_actor(&actor_name, msg).await } @@ -137,7 +116,7 @@ impl SystemMessageHandler { #[async_trait::async_trait] impl Http2ServerHandler for SystemMessageHandler { /// Unified message handler - accepts Message (Single or Stream), returns Message - async fn handle_message_full(&self, path: &str, msg: Message) -> anyhow::Result { + async fn handle_message_full(&self, path: &str, msg: Message) -> Result { self.dispatch_message(path, msg).await } @@ -147,17 +126,12 @@ impl Http2ServerHandler for SystemMessageHandler { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result { + ) -> Result { let msg = Message::single(msg_type, payload); self.dispatch_message(path, msg).await } - async fn handle_tell( - &self, - path: &str, - msg_type: &str, - payload: Vec, - ) -> anyhow::Result<()> { + async fn handle_tell(&self, path: &str, msg_type: &str, payload: Vec) -> Result<()> { let msg = Message::single(msg_type, payload); self.dispatch_tell(path, msg).await } @@ -166,15 +140,18 @@ impl Http2ServerHandler for SystemMessageHandler { &self, payload: Vec, peer_addr: SocketAddr, - ) -> anyhow::Result>> { + ) -> Result>> { let cluster_guard = self.cluster.read().await; if let Some(backend) = cluster_guard.as_ref() { // Try to downcast to GossipBackend to access handle_gossip if let Some(gossip_backend) = backend.as_any().downcast_ref::() { - let msg: GossipMessage = bincode::deserialize(&payload)?; + let msg: GossipMessage = bincode::deserialize(&payload) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; let response = gossip_backend.inner().handle_gossip(msg, peer_addr).await?; if let Some(resp) = response { - Ok(Some(bincode::serialize(&resp)?)) + Ok(Some(bincode::serialize(&resp).map_err(|e| { + PulsingError::from(RuntimeError::Serialization(e.to_string())) + })?)) } else { Ok(None) } @@ -190,12 +167,13 @@ impl Http2ServerHandler for SystemMessageHandler { async fn health_check(&self) -> serde_json::Value { // Collect local actors info let mut actors = Vec::new(); - for entry in self.local_actors.iter() { + for entry in self.registry.iter_actors() { let local_id = *entry.key(); let handle = entry.value(); // Find name from actor_names (reverse lookup) let name = self + .registry .actor_names .iter() .find(|e| *e.value() == local_id) @@ -218,8 +196,8 @@ impl Http2ServerHandler for SystemMessageHandler { // Collect named actors info let named_actors: Vec<_> = self - .named_actor_paths - .iter() + .registry + .iter_named_paths() .map(|e| { serde_json::json!({ "path": e.key().clone(), @@ -272,16 +250,16 @@ impl Http2ServerHandler for SystemMessageHandler { // Count messages from local actors let mut total_messages: u64 = 0; - for entry in self.local_actors.iter() { + for entry in self.registry.iter_actors() { total_messages += entry.value().stats.message_count.load(Ordering::Relaxed); } // Build system metrics let system_metrics = PrometheusMetrics { node_id: self.node_id.0, - actors_count: self.local_actors.len(), + actors_count: self.registry.actor_count(), messages_total: total_messages, - actors_created: self.local_actors.len() as u64, + actors_created: self.registry.actor_count() as u64, actors_stopped: 0, cluster_members, }; @@ -378,8 +356,7 @@ impl Http2ServerHandler for SystemMessageHandler { path: &str, method: &str, body: Vec, - ) -> anyhow::Result>> { - // Call the implementation method + ) -> Result>> { 
SystemMessageHandler::handle_head_api_impl(self, path, method, body).await } } @@ -395,81 +372,80 @@ impl SystemMessageHandler { path: &str, method: &str, body: Vec, - ) -> anyhow::Result>> { + ) -> Result>> { let cluster_guard = self.cluster.read().await; let backend = cluster_guard .as_ref() - .ok_or_else(|| anyhow::anyhow!("Cluster backend not available"))?; + .ok_or_else(|| PulsingError::from(RuntimeError::ClusterNotInitialized))?; - // Try to downcast to HeadNodeBackend let head_backend = match backend.as_any().downcast_ref::() { Some(b) if b.is_head() => b, _ => return Ok(None), }; + let ser = + |e: bincode::Error| PulsingError::from(RuntimeError::Serialization(e.to_string())); + match (method, path) { ("POST", "/cluster/head/register") => { - let req: RegisterNodeRequest = bincode::deserialize(&body)?; + let req: RegisterNodeRequest = bincode::deserialize(&body).map_err(ser)?; head_backend .handle_register_node(req.node_id, req.addr) .await?; Ok(Some(Vec::new())) } ("POST", "/cluster/head/heartbeat") => { - let req: HeartbeatRequest = bincode::deserialize(&body)?; + let req: HeartbeatRequest = bincode::deserialize(&body).map_err(ser)?; head_backend.handle_heartbeat(&req.node_id).await?; - Ok(Some(Vec::new())) // Return empty body for success + Ok(Some(Vec::new())) } ("POST", "/cluster/head/named_actor/register") => { - let req: RegisterNamedActorRequest = bincode::deserialize(&body)?; + let req: RegisterNamedActorRequest = bincode::deserialize(&body).map_err(ser)?; head_backend .handle_register_named_actor(req.path, req.node_id, req.actor_id, req.metadata) .await?; - Ok(Some(Vec::new())) // Return empty body for success + Ok(Some(Vec::new())) } ("POST", "/cluster/head/named_actor/unregister") => { - let req: UnregisterNamedActorRequest = bincode::deserialize(&body)?; + let req: UnregisterNamedActorRequest = bincode::deserialize(&body).map_err(ser)?; head_backend .handle_unregister_named_actor(&req.path, &req.node_id) .await?; - Ok(Some(Vec::new())) // Return empty body for success + Ok(Some(Vec::new())) } ("GET", "/cluster/head/members") | ("POST", "/cluster/head/members") => { - // Support both GET and POST (POST is used by Http2Client) let members = head_backend.all_members().await; - let body = bincode::serialize(&members)?; + let body = bincode::serialize(&members).map_err(ser)?; Ok(Some(body)) } ("GET", "/cluster/head/named_actors") | ("POST", "/cluster/head/named_actors") => { - // Support both GET and POST (POST is used by Http2Client) let named_actors = head_backend.all_named_actors().await; - let body = bincode::serialize(&named_actors)?; + let body = bincode::serialize(&named_actors).map_err(ser)?; Ok(Some(body)) } ("GET", "/cluster/head/sync") | ("POST", "/cluster/head/sync") => { let sync = head_backend.handle_sync().await?; - let body = bincode::serialize(&sync)?; + let body = bincode::serialize(&sync).map_err(ser)?; Ok(Some(body)) } ("POST", "/cluster/head/actor/register") => { - let req: RegisterActorRequest = bincode::deserialize(&body)?; + let req: RegisterActorRequest = bincode::deserialize(&body).map_err(ser)?; head_backend .handle_register_actor(req.actor_id, req.node_id) .await?; Ok(Some(Vec::new())) } ("POST", "/cluster/head/actor/unregister") => { - let req: UnregisterActorRequest = bincode::deserialize(&body)?; + let req: UnregisterActorRequest = bincode::deserialize(&body).map_err(ser)?; head_backend .handle_unregister_actor(&req.actor_id, &req.node_id) .await?; Ok(Some(Vec::new())) } - _ => Err(anyhow::anyhow!( + _ => 
Err(PulsingError::from(RuntimeError::Other(format!( "Unknown head API endpoint: {} {}", - method, - path - )), + method, path + )))), } } } @@ -502,3 +478,141 @@ struct UnregisterNamedActorRequest { path: ActorPath, node_id: NodeId, } + +#[cfg(test)] +mod tests { + use super::*; + use crate::system::handle::{ActorStats, LocalActorHandle}; + use std::collections::HashMap; + + fn make_mock_actor( + _actor_id: ActorId, + ) -> (mpsc::Sender, tokio::task::JoinHandle<()>) { + let (tx, mut rx) = mpsc::channel::(8); + let join = tokio::spawn(async move { + while let Some(env) = rx.recv().await { + let (_, responder) = env.into_parts(); + responder.send(Ok(Message::single("pong", vec![1, 2, 3]))); + } + }); + (tx, join) + } + + #[tokio::test] + async fn test_dispatch_message_invalid_path() { + let registry = Arc::new(ActorRegistry::new()); + let cluster = Arc::new(RwLock::new(None)); + let handler = SystemMessageHandler::new(NodeId::generate(), registry, cluster); + let msg = Message::single("ping", vec![]); + + let err = handler + .handle_message_full("/invalid/path", msg) + .await + .unwrap_err(); + assert!(err.is_runtime()); + assert!(err.to_string().to_lowercase().contains("path")); + } + + #[tokio::test] + async fn test_dispatch_message_actor_not_found() { + let registry = Arc::new(ActorRegistry::new()); + let cluster = Arc::new(RwLock::new(None)); + let handler = SystemMessageHandler::new(NodeId::generate(), registry, cluster); + let msg = Message::single("ping", vec![]); + + let err = handler + .handle_message_full("/actors/missing-actor", msg) + .await + .unwrap_err(); + assert!(err.is_runtime()); + assert!(err.to_string().contains("missing-actor")); + } + + #[tokio::test] + async fn test_dispatch_message_named_actor_not_found() { + let registry = Arc::new(ActorRegistry::new()); + let cluster = Arc::new(RwLock::new(None)); + let handler = SystemMessageHandler::new(NodeId::generate(), registry, cluster); + let msg = Message::single("ping", vec![]); + + let err = handler + .handle_message_full("/named/services/foo", msg) + .await + .unwrap_err(); + assert!(err.is_runtime()); + assert!( + err.to_string().contains("named_actor_not_found") + || err.to_string().contains("services/foo") + ); + } + + #[tokio::test] + async fn test_dispatch_message_success() { + let actor_id = ActorId::generate(); + let (sender, join_handle) = make_mock_actor(actor_id); + let handle = LocalActorHandle { + sender, + join_handle, + cancel_token: tokio_util::sync::CancellationToken::new(), + stats: Arc::new(ActorStats::default()), + metadata: HashMap::new(), + named_path: None, + actor_id, + }; + + let registry = Arc::new(ActorRegistry::new()); + registry.register_actor(actor_id, handle); + registry.register_name("test-actor".to_string(), actor_id); + + let cluster = Arc::new(RwLock::new(None)); + let handler = SystemMessageHandler::new(NodeId::generate(), registry, cluster); + let msg = Message::single("ping", vec![]); + + let response = handler + .handle_message_full("/actors/test-actor", msg) + .await + .unwrap(); + assert_eq!(response.msg_type(), "pong"); + } + + #[tokio::test] + async fn test_handle_message_simple_routing() { + let actor_id = ActorId::generate(); + let (sender, join_handle) = make_mock_actor(actor_id); + let handle = LocalActorHandle { + sender, + join_handle, + cancel_token: tokio_util::sync::CancellationToken::new(), + stats: Arc::new(ActorStats::default()), + metadata: HashMap::new(), + named_path: None, + actor_id, + }; + + let registry = Arc::new(ActorRegistry::new()); + 
registry.register_actor(actor_id, handle); + registry.register_name("simple".to_string(), actor_id); + + let cluster = Arc::new(RwLock::new(None)); + let handler = SystemMessageHandler::new(NodeId::generate(), registry, cluster); + + let response = handler + .handle_message_simple("/actors/simple", "req", vec![1, 2, 3]) + .await + .unwrap(); + assert_eq!(response.msg_type(), "pong"); + } + + #[tokio::test] + async fn test_dispatch_tell_actor_not_found() { + let registry = Arc::new(ActorRegistry::new()); + let cluster = Arc::new(RwLock::new(None)); + let handler = SystemMessageHandler::new(NodeId::generate(), registry, cluster); + + let err = handler + .handle_tell("/actors/missing", "msg", vec![]) + .await + .unwrap_err(); + assert!(err.is_runtime()); + } +} diff --git a/crates/pulsing-actor/src/system/lifecycle.rs b/crates/pulsing-actor/src/system/lifecycle.rs index 628e9a8c7..e0dfaa91b 100644 --- a/crates/pulsing-actor/src/system/lifecycle.rs +++ b/crates/pulsing-actor/src/system/lifecycle.rs @@ -4,6 +4,7 @@ //! for graceful lifecycle management. use crate::actor::{ActorPath, StopReason}; +use crate::error::Result; use crate::system::ActorSystem; use std::time::Duration; use tokio_util::sync::CancellationToken; @@ -17,7 +18,7 @@ impl ActorSystem { /// This method first signals the actor to stop via its cancellation token, /// waits for it to finish (with timeout), then performs cleanup. /// If the actor doesn't stop within the timeout, it will be forcefully aborted. - pub async fn stop(&self, name: impl AsRef) -> anyhow::Result<()> { + pub async fn stop(&self, name: impl AsRef) -> Result<()> { self.stop_with_reason(name, StopReason::Killed).await } @@ -25,18 +26,14 @@ impl ActorSystem { /// /// Note: If the name doesn't contain a "/" and no actor is found with the exact name, /// it will try with the "actors/" prefix (for Python compatibility). 
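The `actors/` prefix fallback in the note above means Python-style short names keep working from Rust as well. A minimal usage sketch, assuming the `pulsing_actor::prelude` re-exports shown in the module docs later in this diff and `pulsing_actor::error::Result` as the public path of the new `Result` alias ("worker" is a hypothetical actor name):

```rust
use std::sync::Arc;

use pulsing_actor::error::Result; // assumed public path of the crate-level Result alias
use pulsing_actor::prelude::*;    // assumed to re-export ActorSystem

async fn stop_worker(system: &Arc<ActorSystem>) -> Result<()> {
    // "worker" contains no '/', and nothing is registered under the bare name,
    // so stop() retries the lookup as "actors/worker" before giving up.
    system.stop("worker").await?;
    Ok(())
}
```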
- pub async fn stop_with_reason( - &self, - name: impl AsRef, - reason: StopReason, - ) -> anyhow::Result<()> { + pub async fn stop_with_reason(&self, name: impl AsRef, reason: StopReason) -> Result<()> { let name = name.as_ref(); - let actual_name = if self.actor_names.contains_key(name) { + let actual_name = if self.registry.has_name(name) { name.to_string() } else if !name.contains('/') { let prefixed = format!("actors/{}", name); - if self.actor_names.contains_key(&prefixed) { + if self.registry.has_name(&prefixed) { prefixed } else { name.to_string() @@ -45,8 +42,8 @@ impl ActorSystem { name.to_string() }; - if let Some((_, local_id)) = self.actor_names.remove(&actual_name) { - if let Some((_, handle)) = self.local_actors.remove(&local_id) { + if let Some((_, local_id)) = self.registry.remove_by_name(&actual_name) { + if let Some((_, handle)) = self.registry.remove_handle(&local_id) { let named_path = handle.named_path.clone(); self.stop_local_actor( &actual_name, @@ -63,7 +60,7 @@ impl ActorSystem { } /// Stop a named actor by path - pub async fn stop_named(&self, path: &crate::actor::ActorPath) -> anyhow::Result<()> { + pub async fn stop_named(&self, path: &crate::actor::ActorPath) -> Result<()> { self.stop_named_with_reason(path, StopReason::Killed).await } @@ -72,15 +69,12 @@ impl ActorSystem { &self, path: &crate::actor::ActorPath, reason: StopReason, - ) -> anyhow::Result<()> { + ) -> Result<()> { let path_key = path.as_str(); - if let Some(actor_name_ref) = self.named_actor_paths.get(&path_key) { - let actor_name = actor_name_ref.clone(); - drop(actor_name_ref); - - if let Some((_, local_id)) = self.actor_names.remove(&actor_name) { - if let Some((_, handle)) = self.local_actors.remove(&local_id) { + if let Some(actor_name) = self.registry.get_actor_name_by_path(&path_key) { + if let Some((_, local_id)) = self.registry.remove_by_name(&actor_name) { + if let Some((_, handle)) = self.registry.remove_handle(&local_id) { self.stop_local_actor( &actor_name, handle, @@ -98,7 +92,7 @@ impl ActorSystem { /// Shutdown the entire actor system /// - pub async fn shutdown(&self) -> anyhow::Result<()> { + pub async fn shutdown(&self) -> Result<()> { tracing::info!("Shutting down actor system"); self.cancel_token.cancel(); @@ -106,13 +100,14 @@ impl ActorSystem { tokio::time::sleep(Duration::from_millis(100)).await; let actor_entries: Vec<_> = self - .local_actors - .iter() + .registry + .iter_actors() .map(|entry| { let local_id = *entry.key(); let actor_id = entry.actor_id; let named_path = entry.named_path.clone(); let name = self + .registry .actor_names .iter() .find(|e| *e.value() == local_id) @@ -123,9 +118,9 @@ impl ActorSystem { .collect(); for (local_id, _actor_id, actor_name, named_path) in actor_entries { - self.actor_names.remove(&actor_name); + self.registry.remove_by_name(&actor_name); - if let Some((_, handle)) = self.local_actors.remove(&local_id) { + if let Some((_, handle)) = self.registry.remove_handle(&local_id) { self.stop_local_actor( &actor_name, handle, @@ -137,12 +132,11 @@ impl ActorSystem { } } - self.local_actors.clear(); - self.actor_names.clear(); + self.registry.clear(); self.node_load.clear(); - self.lifecycle.clear().await; + self.registry.clear_lifecycle().await; { let cluster_guard = self.cluster.read().await; @@ -201,17 +195,18 @@ impl ActorSystem { } // 3. 
Handle lifecycle cleanup - let local_actors = self.local_actors.clone(); - self.lifecycle + let registry = self.registry.clone(); + self.registry + .lifecycle .handle_termination( &handle.actor_id, named_path, reason, - &self.named_actor_paths, + ®istry.named_actor_paths, &self.cluster, |actor_id| { // Directly lookup by ActorId - local_actors.get(actor_id).map(|h| h.sender.clone()) + registry.get_handle(actor_id).map(|h| h.sender.clone()) }, ) .await; diff --git a/crates/pulsing-actor/src/system/mod.rs b/crates/pulsing-actor/src/system/mod.rs index 3e2a8cd4f..c436d5f9d 100644 --- a/crates/pulsing-actor/src/system/mod.rs +++ b/crates/pulsing-actor/src/system/mod.rs @@ -5,12 +5,79 @@ //! - [`SystemConfig`] - Configuration for the actor system //! - [`SpawnOptions`] - Options for spawning actors //! - [`ResolveOptions`] - Options for resolving named actors +//! +//! # Examples +//! +//! ## Creating a Standalone System +//! +//! For single-node development and testing: +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! // Create a standalone system (no network) +//! let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! +//! // The system is ready to spawn actors +//! println!("System started on node: {}", system.node_id()); +//! +//! // Clean shutdown when done +//! system.shutdown().await?; +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Creating a Cluster Node +//! +//! For production multi-node deployment: +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! // Seed node (first node in cluster) +//! let addr: std::net::SocketAddr = "0.0.0.0:8000".parse()?; +//! let config = SystemConfig::with_addr(addr); +//! let seed_system = ActorSystem::new(config).await?; +//! +//! // Worker node joining the cluster +//! let addr: std::net::SocketAddr = "0.0.0.0:8001".parse()?; +//! let seed: std::net::SocketAddr = "127.0.0.1:8000".parse()?; +//! let config = SystemConfig::with_addr(addr) +//! .with_seeds(vec![seed]); +//! let worker_system = ActorSystem::new(config).await?; +//! +//! println!("Cluster formed with 2 nodes"); +//! # Ok(()) +//! # } +//! ``` +//! +//! ## Listing Local Actors +//! +//! ```no_run +//! use pulsing_actor::prelude::*; +//! +//! # #[tokio::main] +//! # async fn main() -> anyhow::Result<()> { +//! # let system = ActorSystem::new(SystemConfig::standalone()).await?; +//! // Get all named actors in this system +//! let names = system.local_actor_names(); +//! for name in names { +//! println!("Actor: {}", name); +//! } +//! # Ok(()) +//! # } +//! 
``` mod config; mod handle; mod handler; mod lifecycle; mod load_balancer; +pub mod registry; mod resolve; mod runtime; mod spawn; @@ -21,16 +88,16 @@ pub use config::{ }; pub use handle::ActorStats; pub use load_balancer::NodeLoadTracker; +pub use registry::ActorRegistry; pub use traits::{ActorSystemCoreExt, ActorSystemOpsExt}; use crate::actor::{ActorId, ActorPath, ActorRef, ActorResolver, ActorSystemRef, Envelope, NodeId}; use crate::cluster::{GossipBackend, HeadNodeBackend, NamingBackend}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::policies::{LoadBalancingPolicy, RoundRobinPolicy}; use crate::system_actor::{BoxedActorFactory, SystemActor, SystemRef, SYSTEM_ACTOR_PATH}; use crate::transport::Http2Transport; -use crate::watch::ActorLifecycle; use dashmap::DashMap; -use handle::LocalActorHandle; use handler::SystemMessageHandler; use std::net::SocketAddr; use std::sync::Arc; @@ -38,7 +105,10 @@ use tokio::sync::mpsc; use tokio::sync::RwLock; use tokio_util::sync::CancellationToken; -/// The Actor System - manages actors and cluster membership +/// The Actor System - manages actors and cluster membership. +/// +/// Actor management (spawn, name lookup, lifecycle) is delegated to +/// [`ActorRegistry`]. Transport, cluster, and load balancing remain here. pub struct ActorSystem { /// Local node ID pub(crate) node_id: NodeId, @@ -49,14 +119,8 @@ pub struct ActorSystem { /// Default mailbox capacity for actors pub(crate) default_mailbox_capacity: usize, - /// Local actors indexed by ActorId (O(1) lookup by ActorId) - pub(crate) local_actors: Arc>, - - /// Actor name to ActorId mapping (for name-based lookups) - pub(crate) actor_names: Arc>, - - /// Named actor path to local actor name mapping (path_string -> actor_name) - pub(crate) named_actor_paths: Arc>, + /// Actor registry: manages local actors, names, paths, lifecycle + pub(crate) registry: Arc, /// Naming backend (for discovery) pub(crate) cluster: Arc>>>, @@ -67,9 +131,6 @@ pub struct ActorSystem { /// Cancellation token pub(crate) cancel_token: CancellationToken, - /// Actor lifecycle manager (watch, termination handling) - pub(crate) lifecycle: Arc, - /// Default load balancing policy pub(crate) default_lb_policy: Arc, @@ -90,24 +151,15 @@ impl ActorSystem { } /// Create a new actor system - pub async fn new(config: SystemConfig) -> anyhow::Result> { + pub async fn new(config: SystemConfig) -> Result> { let cancel_token = CancellationToken::new(); let node_id = NodeId::generate(); - let local_actors: Arc> = Arc::new(DashMap::new()); - let actor_names: Arc> = Arc::new(DashMap::new()); - let named_actor_paths: Arc> = Arc::new(DashMap::new()); + let registry = Arc::new(ActorRegistry::new()); let cluster_holder: Arc>>> = Arc::new(RwLock::new(None)); - let lifecycle = Arc::new(ActorLifecycle::new()); - // Create message handler (needs cluster reference for gossip) - let handler = SystemMessageHandler::new( - node_id, - local_actors.clone(), - actor_names.clone(), - named_actor_paths.clone(), - cluster_holder.clone(), - ); + // Create message handler (needs registry and cluster reference) + let handler = SystemMessageHandler::new(node_id, registry.clone(), cluster_holder.clone()); // Clone http2_config before moving it to transport let http2_config_for_backend = config.http2_config.clone(); @@ -165,13 +217,10 @@ impl ActorSystem { node_id, addr: actual_addr, default_mailbox_capacity: config.default_mailbox_capacity, - local_actors, - actor_names, - named_actor_paths, + registry, cluster: cluster_holder, transport, 
cancel_token, - lifecycle, default_lb_policy: Arc::new(RoundRobinPolicy::new()), node_load: Arc::new(DashMap::new()), }); @@ -183,20 +232,25 @@ impl ActorSystem { } /// Start SystemActor (internal, called during system creation) - async fn start_system_actor(self: &Arc) -> anyhow::Result<()> { + async fn start_system_actor(self: &Arc) -> Result<()> { // Create senders snapshot for SystemRef let local_actor_senders: Arc>> = Arc::new(DashMap::new()); - for entry in self.local_actors.iter() { + for entry in self.registry.iter_actors() { // Find name for this actor (reverse lookup from actor_names) - if let Some(name_entry) = self.actor_names.iter().find(|e| *e.value() == *entry.key()) { + if let Some(name_entry) = self + .registry + .actor_names + .iter() + .find(|e| *e.value() == *entry.key()) + { local_actor_senders.insert(name_entry.key().clone(), entry.sender.clone()); } } // Create named_actor_paths snapshot let named_actor_paths: Arc> = Arc::new(DashMap::new()); - for entry in self.named_actor_paths.iter() { + for entry in self.registry.iter_named_paths() { named_actor_paths.insert(entry.key().clone(), entry.value().clone()); } @@ -229,18 +283,24 @@ impl ActorSystem { pub async fn start_system_actor_with_factory( self: &Arc, factory: BoxedActorFactory, - ) -> anyhow::Result<()> { + ) -> Result<()> { // Check if already started - if self.actor_names.contains_key(SYSTEM_ACTOR_PATH) { - return Err(anyhow::anyhow!("SystemActor already started")); + if self.registry.has_name(SYSTEM_ACTOR_PATH) { + return Err(PulsingError::from(RuntimeError::Other( + "SystemActor already started".into(), + ))); } - // Create SystemRef + // Create SystemRef (snapshot of named paths) + let named_paths_snapshot: Arc> = Arc::new(DashMap::new()); + for entry in self.registry.iter_named_paths() { + named_paths_snapshot.insert(entry.key().clone(), entry.value().clone()); + } let system_ref = Arc::new(SystemRef { node_id: self.node_id, addr: self.addr, local_actors: Arc::new(DashMap::new()), // Will be updated - named_actor_paths: self.named_actor_paths.clone(), + named_actor_paths: named_paths_snapshot, }); // Create SystemActor with custom factory @@ -261,7 +321,7 @@ impl ActorSystem { } /// Get SystemActor reference - pub async fn system(&self) -> anyhow::Result { + pub async fn system(&self) -> Result { self.resolve_named(&ActorPath::new_system(SYSTEM_ACTOR_PATH)?, None) .await } @@ -278,7 +338,7 @@ impl ActorSystem { /// Get list of local actor names pub fn local_actor_names(&self) -> Vec { - self.actor_names.iter().map(|e| e.key().clone()).collect() + self.registry.actor_names_list() } /// Get a local actor reference by name @@ -286,17 +346,13 @@ impl ActorSystem { /// Returns None if the actor doesn't exist locally. /// This is an O(1) operation. 
pub fn local_actor_ref_by_name(&self, name: &str) -> Option { - self.actor_names.get(name).and_then(|local_id| { - self.local_actors - .get(local_id.value()) - .map(|handle| ActorRef::local(handle.actor_id, handle.sender.clone())) - }) + self.registry.local_actor_ref_by_name(name) } } #[async_trait::async_trait] impl ActorSystemRef for ActorSystem { - async fn actor_ref(&self, id: &ActorId) -> anyhow::Result { + async fn actor_ref(&self, id: &ActorId) -> Result { ActorSystem::actor_ref(self, id).await } @@ -304,21 +360,21 @@ impl ActorSystemRef for ActorSystem { self.node_id } - async fn watch(&self, watcher: &ActorId, target: &ActorId) -> anyhow::Result<()> { + async fn watch(&self, watcher: &ActorId, target: &ActorId) -> Result<()> { // Check if target is a local actor - if !self.local_actors.contains_key(target) { - return Err(anyhow::anyhow!( + if self.registry.get_handle(target).is_none() { + return Err(PulsingError::from(RuntimeError::Other(format!( "Cannot watch remote actor: {} (watching remote actors not yet supported)", target - )); + )))); } - self.lifecycle.watch(watcher, target).await; + self.registry.lifecycle.watch(watcher, target).await; Ok(()) } - async fn unwatch(&self, watcher: &ActorId, target: &ActorId) -> anyhow::Result<()> { - self.lifecycle.unwatch(watcher, target).await; + async fn unwatch(&self, watcher: &ActorId, target: &ActorId) -> Result<()> { + self.registry.lifecycle.unwatch(watcher, target).await; Ok(()) } @@ -332,7 +388,7 @@ impl ActorSystemRef for ActorSystem { /// This enables lazy ActorRef to resolve named actors on demand. #[async_trait::async_trait] impl ActorResolver for ActorSystem { - async fn resolve_path(&self, path: &ActorPath) -> anyhow::Result { + async fn resolve_path(&self, path: &ActorPath) -> Result { // Use direct resolution (not lazy) to avoid infinite recursion self.resolve_named_direct(path, None).await } diff --git a/crates/pulsing-actor/src/system/registry.rs b/crates/pulsing-actor/src/system/registry.rs new file mode 100644 index 000000000..577d9e699 --- /dev/null +++ b/crates/pulsing-actor/src/system/registry.rs @@ -0,0 +1,187 @@ +//! Actor Registry - manages local actor instances, name mappings, and lifecycle. +//! +//! This module extracts actor management concerns from ActorSystem into a +//! focused subsystem, reducing ActorSystem's responsibilities. + +use crate::actor::{ActorId, ActorRef, Envelope}; +use crate::watch::ActorLifecycle; +use dashmap::DashMap; +use tokio::sync::mpsc; + +use super::handle::LocalActorHandle; + +/// Actor Registry - manages local actor instances and name resolution. +/// +/// Extracted from ActorSystem to separate actor management concerns from +/// transport, cluster, and system configuration. 
+/// +/// Responsibilities: +/// - Local actor instance storage (ActorId → handle) +/// - Name-to-ActorId mapping (name → ActorId) +/// - Named actor path mapping (path → actor_name) +/// - Actor lifecycle management (watchers, termination) +pub struct ActorRegistry { + /// Local actors indexed by ActorId (O(1) lookup) + pub(crate) local_actors: DashMap, + + /// Actor name to ActorId mapping (for name-based lookups) + pub(crate) actor_names: DashMap, + + /// Named actor path to local actor name mapping (path_string -> actor_name) + pub(crate) named_actor_paths: DashMap, + + /// Actor lifecycle manager (watch, termination handling) + pub(crate) lifecycle: ActorLifecycle, +} + +impl ActorRegistry { + /// Create a new empty registry + pub fn new() -> Self { + Self { + local_actors: DashMap::new(), + actor_names: DashMap::new(), + named_actor_paths: DashMap::new(), + lifecycle: ActorLifecycle::new(), + } + } + + // ========================================================================= + // Registration + // ========================================================================= + + /// Register a local actor handle + pub(crate) fn register_actor(&self, actor_id: ActorId, handle: LocalActorHandle) { + self.local_actors.insert(actor_id, handle); + } + + /// Register a name-to-ActorId mapping + pub fn register_name(&self, name: String, actor_id: ActorId) { + self.actor_names.insert(name, actor_id); + } + + /// Register a named actor path mapping + pub fn register_named_path(&self, path: String, actor_name: String) { + self.named_actor_paths.insert(path, actor_name); + } + + /// Check if a name is already registered + pub fn has_name(&self, name: &str) -> bool { + self.actor_names.contains_key(name) + } + + /// Check if a named path is already registered + pub fn has_named_path(&self, path: &str) -> bool { + self.named_actor_paths.contains_key(path) + } + + // ========================================================================= + // Lookup + // ========================================================================= + + /// Get a local actor handle by ActorId + pub(crate) fn get_handle( + &self, + actor_id: &ActorId, + ) -> Option> { + self.local_actors.get(actor_id) + } + + /// Get ActorId by name + pub fn get_actor_id(&self, name: &str) -> Option { + self.actor_names.get(name).map(|r| *r.value()) + } + + /// Get actor name from named path + pub fn get_actor_name_by_path(&self, path: &str) -> Option { + self.named_actor_paths.get(path).map(|r| r.value().clone()) + } + + /// Get a local actor reference by name (O(1) via name → id → handle) + pub fn local_actor_ref_by_name(&self, name: &str) -> Option { + self.actor_names.get(name).and_then(|actor_id| { + self.local_actors + .get(actor_id.value()) + .map(|handle| ActorRef::local(handle.actor_id, handle.sender.clone())) + }) + } + + /// Find actor sender by name or ActorId string (for HTTP handler) + pub fn find_actor_sender(&self, actor_name: &str) -> Option> { + // First try by name → ActorId → handle + if let Some(actor_id) = self.actor_names.get(actor_name) { + if let Some(handle) = self.local_actors.get(actor_id.value()) { + return Some(handle.sender.clone()); + } + } + + // Then try parsing as ActorId (UUID format) + if let Ok(uuid) = uuid::Uuid::parse_str(actor_name) { + let actor_id = ActorId::new(uuid.as_u128()); + if let Some(handle) = self.local_actors.get(&actor_id) { + return Some(handle.sender.clone()); + } + } + + None + } + + /// Get list of all local actor names + pub fn actor_names_list(&self) -> Vec { + 
self.actor_names.iter().map(|e| e.key().clone()).collect() + } + + // ========================================================================= + // Removal + // ========================================================================= + + /// Remove an actor by name, returning (name, actor_id) + pub fn remove_by_name(&self, name: &str) -> Option<(String, ActorId)> { + self.actor_names.remove(name) + } + + /// Remove a local actor handle, returning it + pub(crate) fn remove_handle(&self, actor_id: &ActorId) -> Option<(ActorId, LocalActorHandle)> { + self.local_actors.remove(actor_id) + } + + // ========================================================================= + // Iteration (for shutdown, health checks, etc.) + // ========================================================================= + + /// Iterate over all local actors (for health checks, metrics, etc.) + pub(crate) fn iter_actors(&self) -> dashmap::iter::Iter<'_, ActorId, LocalActorHandle> { + self.local_actors.iter() + } + + /// Iterate over named actor paths + pub fn iter_named_paths(&self) -> dashmap::iter::Iter<'_, String, String> { + self.named_actor_paths.iter() + } + + /// Get count of local actors + pub fn actor_count(&self) -> usize { + self.local_actors.len() + } + + // ========================================================================= + // Bulk operations (for shutdown) + // ========================================================================= + + /// Clear all registrations (for shutdown) + pub fn clear(&self) { + self.local_actors.clear(); + self.actor_names.clear(); + self.named_actor_paths.clear(); + } + + /// Clear lifecycle watchers + pub async fn clear_lifecycle(&self) { + self.lifecycle.clear().await; + } +} + +impl Default for ActorRegistry { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/pulsing-actor/src/system/resolve.rs b/crates/pulsing-actor/src/system/resolve.rs index cce4959a5..214260d24 100644 --- a/crates/pulsing-actor/src/system/resolve.rs +++ b/crates/pulsing-actor/src/system/resolve.rs @@ -7,7 +7,7 @@ use crate::actor::{ ActorAddress, ActorId, ActorPath, ActorRef, ActorResolver, IntoActorPath, NodeId, }; use crate::cluster::{MemberInfo, MemberStatus, NamedActorInfo}; -use crate::error::{PulsingError, RuntimeError}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::policies::LoadBalancingPolicy; use crate::system::config::ResolveOptions; use crate::system::load_balancer::{MemberWorker, NodeLoadTracker}; @@ -22,35 +22,33 @@ impl ActorSystem { self.cluster.read().await.as_ref().cloned() } - async fn cluster_or_err(&self) -> anyhow::Result> { + async fn cluster_or_err(&self) -> Result> { self.cluster_opt() .await - .ok_or_else(|| anyhow::anyhow!("Cluster not initialized")) + .ok_or_else(|| PulsingError::from(RuntimeError::ClusterNotInitialized)) } /// Get ActorRef for a local or remote actor by ID /// /// This is an O(1) operation for local actors using ActorId indexing. 
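The O(1) lookups referenced here are backed by the `ActorRegistry` introduced above. A short sketch of its name-mapping surface; the module path and `ActorId::generate()` are assumptions based on the `pub use registry::ActorRegistry` and the tests in this diff, and the names "actors/worker" / "services/worker" are hypothetical:

```rust
use pulsing_actor::actor::ActorId;        // assumed module path for ActorId
use pulsing_actor::system::ActorRegistry; // assumed re-export path

fn registry_demo() {
    let registry = ActorRegistry::new();
    let id = ActorId::generate();

    // Name bookkeeping: name -> ActorId, plus the named-path -> name mapping.
    registry.register_name("actors/worker".to_string(), id);
    registry.register_named_path("services/worker".to_string(), "actors/worker".to_string());

    assert!(registry.has_name("actors/worker"));
    assert_eq!(registry.get_actor_id("actors/worker"), Some(id));
    assert_eq!(
        registry.get_actor_name_by_path("services/worker").as_deref(),
        Some("actors/worker")
    );

    // Removal hands back the dropped mapping; clear() wipes everything on shutdown.
    let removed = registry.remove_by_name("actors/worker");
    assert_eq!(removed.map(|(_, removed_id)| removed_id), Some(id));
    registry.clear();
}
```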
- pub async fn actor_ref(&self, id: &ActorId) -> anyhow::Result { + pub async fn actor_ref(&self, id: &ActorId) -> Result { // Try local lookup first (O(1)) - if let Some(handle) = self.local_actors.get(id) { + if let Some(handle) = self.registry.get_handle(id) { return Ok(ActorRef::local(handle.actor_id, handle.sender.clone())); } // Not found locally - try remote lookup via cluster - // Note: With UUID-based IDs, we need to check cluster for actor location let cluster = self.cluster_or_err().await?; // Lookup actor location in cluster if let Some(member_info) = cluster.lookup_actor(id).await { - // Create remote transport using actor id let transport = Http2RemoteTransport::new_by_id(self.transport.client(), member_info.addr, *id); return Ok(ActorRef::remote(*id, member_info.addr, Arc::new(transport))); } - Err(anyhow::Error::from(PulsingError::from( - RuntimeError::actor_not_found(id.to_string()), + Err(PulsingError::from(RuntimeError::actor_not_found( + id.to_string(), ))) } @@ -64,11 +62,7 @@ impl ActorSystem { /// ```rust,ignore /// let actor = system.resolve_named("services/echo", None).await?; /// ``` - pub async fn resolve_named

<P>( - &self, - path: P, - node_id: Option<&NodeId>, - ) -> anyhow::Result<ActorRef> + pub async fn resolve_named<P>
(&self, path: P, node_id: Option<&NodeId>) -> Result where P: IntoActorPath, { @@ -91,7 +85,7 @@ impl ActorSystem { /// let actor = system.resolve_named_lazy("services/echo").await?; /// // Even if the actor migrates, this ref will find it after cache expires /// ``` - pub fn resolve_named_lazy

<P>(self: &Arc<Self>, path: P) -> anyhow::Result<ActorRef> + pub fn resolve_named_lazy<P>
(self: &Arc, path: P) -> Result where P: IntoActorPath, { @@ -104,7 +98,7 @@ impl ActorSystem { &self, path: &ActorPath, node_id: Option<&NodeId>, - ) -> anyhow::Result { + ) -> Result { let options = if let Some(nid) = node_id { ResolveOptions::default().node_id(*nid) } else { @@ -118,13 +112,36 @@ impl ActorSystem { &self, path: &ActorPath, options: ResolveOptions, - ) -> anyhow::Result { + ) -> Result { let cluster = self.cluster_or_err().await?; let instances = cluster.get_named_actor_instances(path).await; + // When node_id is specified but gossip hasn't propagated the path yet (e.g. Topic + // subscriber on another node), resolve by target node directly instead of failing. if instances.is_empty() { - return Err(anyhow::anyhow!("Named actor not found: {}", path.as_str())); + if let Some(nid) = options.node_id { + if nid != self.node_id { + if let Some(member) = cluster.get_member(&nid).await { + if !options.filter_alive || member.status == MemberStatus::Alive { + let transport = Http2RemoteTransport::new_named( + self.transport.client(), + member.addr, + path.clone(), + ); + let actor_id = ActorId::generate(); + return Ok(ActorRef::remote( + actor_id, + member.addr, + Arc::new(transport), + )); + } + } + } + } + return Err(PulsingError::from(RuntimeError::named_actor_not_found( + path.as_str(), + ))); } let healthy_instances: Vec<_> = if options.filter_alive { @@ -137,17 +154,16 @@ impl ActorSystem { }; if healthy_instances.is_empty() { - return Err(anyhow::anyhow!( - "No healthy instances for named actor: {}", - path.as_str() - )); + return Err(PulsingError::from(RuntimeError::no_healthy_instances( + path.as_str(), + ))); } let target = if let Some(nid) = options.node_id { healthy_instances .iter() .find(|i| i.node_id == nid) - .ok_or_else(|| anyhow::anyhow!("Actor instance not found on node: {}", nid))? + .ok_or_else(|| PulsingError::from(RuntimeError::node_not_found(nid.to_string())))? } else { let policy = options.policy.as_ref().unwrap_or(&self.default_lb_policy); self.select_instance(&healthy_instances, policy.as_ref()) @@ -155,21 +171,19 @@ impl ActorSystem { if target.node_id == self.node_id { let actor_name = self - .named_actor_paths - .get(&path.as_str()) - .ok_or_else(|| anyhow::anyhow!("Named actor not found locally"))? 
- .clone(); - - let local_id = self.actor_names.get(&actor_name).ok_or_else(|| { - anyhow::Error::from(PulsingError::from(RuntimeError::actor_not_found( - actor_name.clone(), - ))) + .registry + .get_actor_name_by_path(&path.as_str()) + .ok_or_else(|| { + PulsingError::from(RuntimeError::named_actor_not_found(path.as_str())) + })?; + + let local_id = self.registry.get_actor_id(&actor_name).ok_or_else(|| { + PulsingError::from(RuntimeError::actor_not_found(actor_name.clone())) })?; - let handle = self - .local_actors - .get(local_id.value()) - .ok_or_else(|| anyhow::anyhow!("Actor handle not found: {}", actor_name))?; + let handle = self.registry.get_handle(&local_id).ok_or_else(|| { + PulsingError::from(RuntimeError::actor_not_found(actor_name.clone())) + })?; return Ok(ActorRef::local(handle.actor_id, handle.sender.clone())); } @@ -256,7 +270,7 @@ impl ActorSystem { } /// Resolve an actor address and get an ActorRef - pub async fn resolve(&self, address: &ActorAddress) -> anyhow::Result { + pub async fn resolve(&self, address: &ActorAddress) -> Result { match address { ActorAddress::Named { path, instance } => { self.resolve_named(path, instance.as_ref()).await @@ -281,7 +295,7 @@ impl ActorSystem { &self, path: &ActorPath, filter_alive: bool, - ) -> anyhow::Result> { + ) -> Result> { let cluster = self.cluster_or_err().await?; let instances = cluster.get_named_actor_instances_detailed(path).await; @@ -338,3 +352,59 @@ impl ActorSystem { } } } + +#[cfg(test)] +mod tests { + use crate::system::config::ActorSystemBuilder; + + #[tokio::test] + async fn test_resolve_named_invalid_path_single_segment() { + let system = ActorSystemBuilder::default() + .addr("127.0.0.1:0") + .build() + .await + .unwrap(); + let err = system.resolve_named("x", None).await.unwrap_err(); + let msg = err.to_string().to_lowercase(); + assert!( + msg.contains("namespace") || msg.contains("path"), + "expected path/namespace error, got: {}", + err + ); + } + + #[tokio::test] + async fn test_resolve_named_lazy_invalid_path() { + let system = ActorSystemBuilder::default() + .addr("127.0.0.1:0") + .build() + .await + .unwrap(); + let err = system.resolve_named_lazy("single").unwrap_err(); + let msg = err.to_string().to_lowercase(); + assert!( + msg.contains("namespace") || msg.contains("path"), + "expected path/namespace error, got: {}", + err + ); + } + + #[tokio::test] + async fn test_resolve_named_valid_path_no_instances() { + let system = ActorSystemBuilder::default() + .addr("127.0.0.1:0") + .build() + .await + .unwrap(); + let err = system + .resolve_named("svc/nonexistent", None) + .await + .unwrap_err(); + let msg = err.to_string(); + assert!( + msg.contains("nonexistent") || msg.contains("not_found") || msg.contains("not found"), + "expected named_actor_not_found, got: {}", + err + ); + } +} diff --git a/crates/pulsing-actor/src/system/runtime.rs b/crates/pulsing-actor/src/system/runtime.rs index d75ac22fb..143734a68 100644 --- a/crates/pulsing-actor/src/system/runtime.rs +++ b/crates/pulsing-actor/src/system/runtime.rs @@ -35,10 +35,9 @@ pub(crate) async fn run_actor_instance( responder.send(Ok(response)); } Err(e) => { - tracing::error!(actor_id = ?ctx.id(), error = %e, "Actor error"); - responder.send(Err(anyhow::anyhow!("Handler error: {}", e))); - // Actor crashes on error - supervision will decide whether to restart - return StopReason::Failed(e.to_string()); + // 业务错误:receive 返回 Err,只把错误返回给调用者,actor 继续处理下一条消息 + tracing::warn!(actor_id = ?ctx.id(), error = %e, "Receive returned error (returned to caller)"); 
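The new `receive` error semantics shown here (log a warning, hand the `Err` back to the caller, keep the actor alive) can be relied on by handlers that reject unknown messages. A sketch in the style of `TestEchoActor` later in this diff; the prelude re-exports and the `pulsing_actor::error` module path are assumptions, and `Counter` is a hypothetical actor:

```rust
use async_trait::async_trait;
use pulsing_actor::error::{PulsingError, Result, RuntimeError}; // assumed public error module
use pulsing_actor::prelude::*; // assumed to re-export Actor, ActorContext, Message

struct Counter {
    seen: u64,
}

#[async_trait]
impl Actor for Counter {
    async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> Result<Message> {
        if msg.msg_type() == "count" {
            self.seen += 1;
            return Ok(Message::single("count_ack", self.seen.to_le_bytes().to_vec()));
        }
        // With the runtime change above, this Err is delivered to the asker and the
        // actor stays alive to handle the next "count" message.
        Err(PulsingError::from(RuntimeError::Other(format!(
            "Unknown message type: {}",
            msg.msg_type()
        ))))
    }
}
```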
+ responder.send(Err(e)); } } } @@ -77,7 +76,7 @@ pub(crate) async fn run_supervision_loop( spec: SupervisionSpec, ) -> StopReason where - F: FnMut() -> anyhow::Result + Send + 'static, + F: FnMut() -> crate::error::Result + Send + 'static, A: Actor, { let mut restarts = 0; diff --git a/crates/pulsing-actor/src/system/spawn.rs b/crates/pulsing-actor/src/system/spawn.rs index 14adee7f6..1bca313a4 100644 --- a/crates/pulsing-actor/src/system/spawn.rs +++ b/crates/pulsing-actor/src/system/spawn.rs @@ -7,11 +7,12 @@ //! All other spawn methods delegate to the builder. use crate::actor::{Actor, ActorContext, ActorId, ActorPath, ActorRef, ActorSystemRef, Mailbox}; -use crate::error::{PulsingError, RuntimeError}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::system::config::SpawnOptions; use crate::system::handle::{ActorStats, LocalActorHandle}; use crate::system::runtime::run_supervision_loop; use crate::system::ActorSystem; +use dashmap::mapref::entry::Entry; use std::sync::Arc; impl ActorSystem { @@ -19,30 +20,21 @@ impl ActorSystem { /// /// This is called by `SpawnBuilder::spawn_factory()` and handles both /// anonymous and named actor spawning. + /// + /// Name registration uses DashMap::entry() for atomic insert, avoiding + /// TOCTOU races when two concurrent spawns use the same name. pub(crate) async fn spawn_internal( self: &Arc, path: Option, factory: F, options: SpawnOptions, - ) -> anyhow::Result + ) -> Result where - F: FnMut() -> anyhow::Result + Send + 'static, + F: FnMut() -> Result + Send + 'static, A: Actor, { let name_str = path.as_ref().map(|p| p.as_str().to_string()); - // Check for name conflicts (only for named actors) - if let Some(ref name) = name_str { - if self.actor_names.contains_key(name) { - return Err(anyhow::Error::from(PulsingError::from( - RuntimeError::actor_already_exists(name.clone()), - ))); - } - if self.named_actor_paths.contains_key(name) { - return Err(anyhow::anyhow!("Named path already registered: {}", name)); - } - } - let actor_id = self.next_actor_id(); let mailbox = Mailbox::with_capacity(self.mailbox_capacity(&options)); @@ -77,12 +69,25 @@ impl ActorSystem { actor_id, }; - self.local_actors.insert(actor_id, handle); + self.registry.register_actor(actor_id, handle); - // Register in name maps + // Register in name maps. For named actors use atomic entry() to avoid TOCTOU. 
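As a standalone illustration of the pattern the comment above refers to (plain `dashmap`, no pulsing types): `entry()` holds the shard lock across the check and the insert, so two concurrent registrations of the same name cannot both succeed.

```rust
use dashmap::mapref::entry::Entry;
use dashmap::DashMap;

/// Atomically claim `name` for `id`; returns Err if the name is already taken.
fn claim_name(names: &DashMap<String, u64>, name: &str, id: u64) -> Result<(), String> {
    match names.entry(name.to_string()) {
        // Someone else won the race: nothing is inserted, report the conflict.
        Entry::Occupied(_) => Err(format!("name already registered: {name}")),
        // The shard lock is held here, so check-and-insert is one atomic step.
        Entry::Vacant(slot) => {
            slot.insert(id);
            Ok(())
        }
    }
}
```

This mirrors what `spawn_internal` does with `registry.actor_names` below, including tearing down the freshly registered handle when it loses the race.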
if let Some(ref name) = name_str { - self.actor_names.insert(name.clone(), actor_id); - self.named_actor_paths.insert(name.clone(), name.clone()); + match self.registry.actor_names.entry(name.clone()) { + Entry::Occupied(_) => { + if let Some((_, dropped_handle)) = self.registry.remove_handle(&actor_id) { + dropped_handle.cancel_token.cancel(); + } + return Err(PulsingError::from(RuntimeError::actor_already_exists( + name.clone(), + ))); + } + Entry::Vacant(v) => { + v.insert(actor_id); + } + } + self.registry + .register_named_path(name.clone(), name.clone()); // Register with cluster if available if let Some(ref path) = path { @@ -98,7 +103,7 @@ impl ActorSystem { } } else { // Anonymous actor: use actor_id as key - self.actor_names.insert(actor_id.to_string(), actor_id); + self.registry.register_name(actor_id.to_string(), actor_id); } Ok(ActorRef::local(actor_id, sender)) diff --git a/crates/pulsing-actor/src/system/traits.rs b/crates/pulsing-actor/src/system/traits.rs index d24fdfae8..1f4c5ea24 100644 --- a/crates/pulsing-actor/src/system/traits.rs +++ b/crates/pulsing-actor/src/system/traits.rs @@ -6,6 +6,7 @@ use std::sync::Arc; use crate::actor::{Actor, ActorId, ActorPath, ActorRef, IntoActor, IntoActorPath, NodeId}; use crate::cluster::{MemberInfo, NamedActorInfo}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::supervision::SupervisionSpec; use crate::system_actor::BoxedActorFactory; @@ -23,7 +24,7 @@ pub trait ActorSystemCoreExt: Sized { /// Accepts any type that implements `IntoActor`, including: /// - Types implementing `Actor` directly /// - `Behavior` (automatically wrapped) - async fn spawn(&self, actor: A) -> anyhow::Result + async fn spawn(&self, actor: A) -> Result where A: IntoActor; @@ -38,11 +39,7 @@ pub trait ActorSystemCoreExt: Sized { /// # Arguments /// - `name` - The name for discovery (e.g., "services/echo") /// - `actor` - The actor instance or Behavior - async fn spawn_named( - &self, - name: impl AsRef + Send, - actor: A, - ) -> anyhow::Result + async fn spawn_named(&self, name: impl AsRef + Send, actor: A) -> Result where A: IntoActor; @@ -50,10 +47,10 @@ pub trait ActorSystemCoreExt: Sized { fn spawning(&self) -> SpawnBuilder<'_>; /// Get ActorRef for a local or remote actor by ID - async fn actor_ref(&self, id: &ActorId) -> anyhow::Result; + async fn actor_ref(&self, id: &ActorId) -> Result; /// Resolve a named actor by name. - async fn resolve

<P>(&self, name: P) -> anyhow::Result<ActorRef> + async fn resolve<P>
(&self, name: P) -> Result where P: IntoActorPath + Send; @@ -139,7 +136,7 @@ impl<'a> SpawnBuilder<'a> { /// /// If a name was set, spawns a named actor (resolvable). /// Otherwise, spawns an anonymous actor (only accessible via ActorRef). - pub async fn spawn(self, actor: A) -> anyhow::Result + pub async fn spawn(self, actor: A) -> Result where A: IntoActor, { @@ -147,9 +144,11 @@ impl<'a> SpawnBuilder<'a> { // Create a once-use factory from the actor instance let mut actor_opt = Some(actor); let factory = move || { - actor_opt - .take() - .ok_or_else(|| anyhow::anyhow!("Actor cannot be restarted (spawned as instance)")) + actor_opt.take().ok_or_else(|| { + PulsingError::from(RuntimeError::actor_spawn_failed( + "Actor cannot be restarted (spawned as instance)", + )) + }) }; self.spawn_factory(factory).await } @@ -161,14 +160,16 @@ impl<'a> SpawnBuilder<'a> { /// /// Note: Only named actors support supervision/restart. Anonymous actors /// cannot be restarted because they have no stable identity for re-resolution. - pub async fn spawn_factory(self, factory: F) -> anyhow::Result + pub async fn spawn_factory(self, factory: F) -> Result where - F: FnMut() -> anyhow::Result + Send + 'static, + F: FnMut() -> Result + Send + 'static, A: Actor, { // Check if name validation failed if let Some(ref error) = self.name_error { - return Err(anyhow::anyhow!("{}", error)); + return Err(PulsingError::from(RuntimeError::invalid_actor_path( + error.clone(), + ))); } match self.name { @@ -218,7 +219,7 @@ impl<'a> ResolveBuilder<'a> { } /// Resolve a named actor - pub async fn resolve

<P>(self, name: P) -> anyhow::Result<ActorRef> + pub async fn resolve<P>
(self, name: P) -> Result where P: IntoActorPath + Send, { @@ -227,7 +228,7 @@ impl<'a> ResolveBuilder<'a> { } /// List all instances of a named actor - pub async fn list

<P>(self, name: P) -> anyhow::Result<Vec<NamedActorInfo>> + pub async fn list<P>
(self, name: P) -> Result> where P: IntoActorPath + Send, { @@ -236,7 +237,7 @@ impl<'a> ResolveBuilder<'a> { } /// Lazy resolve - returns ActorRef that auto re-resolves when stale - pub fn lazy

<P>(self, name: P) -> anyhow::Result<ActorRef> + pub fn lazy<P>
(self, name: P) -> Result where P: IntoActorPath, { @@ -248,13 +249,10 @@ impl<'a> ResolveBuilder<'a> { #[async_trait::async_trait] pub trait ActorSystemOpsExt { /// Get SystemActor reference - async fn system(&self) -> anyhow::Result; + async fn system(&self) -> Result; /// Start SystemActor with custom factory (for Python extension) - async fn start_system_actor_with_factory( - &self, - factory: BoxedActorFactory, - ) -> anyhow::Result<()>; + async fn start_system_actor_with_factory(&self, factory: BoxedActorFactory) -> Result<()>; /// Get node ID fn node_id(&self) -> &NodeId; @@ -290,10 +288,7 @@ pub trait ActorSystemOpsExt { fn tracked_node_count(&self) -> usize; /// Resolve an actor address and get an ActorRef - async fn resolve_address( - &self, - address: &crate::actor::ActorAddress, - ) -> anyhow::Result; + async fn resolve_address(&self, address: &crate::actor::ActorAddress) -> Result; /// Get detailed instances with actor_id and metadata async fn get_named_instances_detailed( @@ -311,27 +306,27 @@ pub trait ActorSystemOpsExt { async fn members(&self) -> Vec; /// Stop an actor by local name - async fn stop(&self, name: impl AsRef + Send) -> anyhow::Result<()>; + async fn stop(&self, name: impl AsRef + Send) -> Result<()>; /// Stop an actor with a specific reason async fn stop_with_reason( &self, name: impl AsRef + Send, reason: crate::actor::StopReason, - ) -> anyhow::Result<()>; + ) -> Result<()>; /// Stop a named actor by path - async fn stop_named(&self, path: &ActorPath) -> anyhow::Result<()>; + async fn stop_named(&self, path: &ActorPath) -> Result<()>; /// Stop a named actor by path with a specific reason async fn stop_named_with_reason( &self, path: &ActorPath, reason: crate::actor::StopReason, - ) -> anyhow::Result<()>; + ) -> Result<()>; /// Shutdown the entire actor system - async fn shutdown(&self) -> anyhow::Result<()>; + async fn shutdown(&self) -> Result<()>; /// Get cancellation token fn cancel_token(&self) -> CancellationToken; @@ -345,18 +340,14 @@ use super::ActorSystem; #[async_trait::async_trait] impl ActorSystemCoreExt for Arc { - async fn spawn(&self, actor: A) -> anyhow::Result + async fn spawn(&self, actor: A) -> Result where A: IntoActor, { self.spawning().spawn(actor).await } - async fn spawn_named( - &self, - name: impl AsRef + Send, - actor: A, - ) -> anyhow::Result + async fn spawn_named(&self, name: impl AsRef + Send, actor: A) -> Result where A: IntoActor, { @@ -367,11 +358,11 @@ impl ActorSystemCoreExt for Arc { SpawnBuilder::new(self) } - async fn actor_ref(&self, id: &ActorId) -> anyhow::Result { + async fn actor_ref(&self, id: &ActorId) -> Result { ActorSystem::actor_ref(self.as_ref(), id).await } - async fn resolve

<P>(&self, name: P) -> anyhow::Result<ActorRef> + async fn resolve<P>
(&self, name: P) -> Result where P: IntoActorPath + Send, { @@ -385,14 +376,11 @@ impl ActorSystemCoreExt for Arc { #[async_trait::async_trait] impl ActorSystemOpsExt for Arc { - async fn system(&self) -> anyhow::Result { + async fn system(&self) -> Result { ActorSystem::system(self.as_ref()).await } - async fn start_system_actor_with_factory( - &self, - factory: BoxedActorFactory, - ) -> anyhow::Result<()> { + async fn start_system_actor_with_factory(&self, factory: BoxedActorFactory) -> Result<()> { ActorSystem::start_system_actor_with_factory(self, factory).await } @@ -428,10 +416,7 @@ impl ActorSystemOpsExt for Arc { ActorSystem::tracked_node_count(self.as_ref()) } - async fn resolve_address( - &self, - address: &crate::actor::ActorAddress, - ) -> anyhow::Result { + async fn resolve_address(&self, address: &crate::actor::ActorAddress) -> Result { ActorSystem::resolve(self.as_ref(), address).await } @@ -454,7 +439,7 @@ impl ActorSystemOpsExt for Arc { ActorSystem::members(self.as_ref()).await } - async fn stop(&self, name: impl AsRef + Send) -> anyhow::Result<()> { + async fn stop(&self, name: impl AsRef + Send) -> Result<()> { ActorSystem::stop(self.as_ref(), name).await } @@ -462,11 +447,11 @@ impl ActorSystemOpsExt for Arc { &self, name: impl AsRef + Send, reason: crate::actor::StopReason, - ) -> anyhow::Result<()> { + ) -> Result<()> { ActorSystem::stop_with_reason(self.as_ref(), name, reason).await } - async fn stop_named(&self, path: &ActorPath) -> anyhow::Result<()> { + async fn stop_named(&self, path: &ActorPath) -> Result<()> { ActorSystem::stop_named(self.as_ref(), path).await } @@ -474,11 +459,11 @@ impl ActorSystemOpsExt for Arc { &self, path: &ActorPath, reason: crate::actor::StopReason, - ) -> anyhow::Result<()> { + ) -> Result<()> { ActorSystem::stop_named_with_reason(self.as_ref(), path, reason).await } - async fn shutdown(&self) -> anyhow::Result<()> { + async fn shutdown(&self) -> Result<()> { ActorSystem::shutdown(self.as_ref()).await } diff --git a/crates/pulsing-actor/src/system_actor/mod.rs b/crates/pulsing-actor/src/system_actor/mod.rs index d426f64c6..28fa6ba79 100644 --- a/crates/pulsing-actor/src/system_actor/mod.rs +++ b/crates/pulsing-actor/src/system_actor/mod.rs @@ -25,6 +25,7 @@ pub use factory::{ActorFactory, BoxedActorFactory, DefaultActorFactory}; pub use messages::{ActorInfo, ActorStatusInfo, SystemMessage, SystemResponse}; use crate::actor::{Actor, ActorContext, ActorId, Message}; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::metrics::SystemMetrics as PrometheusSystemMetrics; use dashmap::DashMap; use std::collections::HashMap; @@ -294,11 +295,12 @@ impl SystemActor { } /// Generate JSON error response - fn json_error_response(&self, message: &str) -> anyhow::Result { + fn json_error_response(&self, message: &str) -> Result { let response = SystemResponse::Error { message: message.to_string(), }; - let json_data = serde_json::to_vec(&response)?; + let json_data = serde_json::to_vec(&response) + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; Ok(Message::Single { msg_type: "SystemResponse".to_string(), data: json_data, @@ -316,7 +318,7 @@ impl Actor for SystemActor { meta } - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> Result<()> { tracing::info!( actor_id = ?ctx.id(), path = SYSTEM_ACTOR_PATH, @@ -325,7 +327,7 @@ impl Actor for SystemActor { Ok(()) } - async fn on_stop(&mut self, ctx: &mut ActorContext) -> 
anyhow::Result<()> { + async fn on_stop(&mut self, ctx: &mut ActorContext) -> Result<()> { tracing::info!( actor_id = ?ctx.id(), path = SYSTEM_ACTOR_PATH, @@ -334,7 +336,7 @@ impl Actor for SystemActor { Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> Result { self.metrics.inc_message(); // Parse system message using auto-detection (JSON first, then bincode) @@ -392,7 +394,7 @@ impl Actor for SystemActor { // Use JSON serialization for response (for Python compatibility) let json_data = serde_json::to_vec(&response) - .map_err(|e| anyhow::anyhow!("Failed to serialize response: {}", e))?; + .map_err(|e| PulsingError::from(RuntimeError::Serialization(e.to_string())))?; Ok(Message::Single { msg_type: "SystemResponse".to_string(), data: json_data, diff --git a/crates/pulsing-actor/src/test_helper.rs b/crates/pulsing-actor/src/test_helper.rs index fe5b4de6f..c81df1f3e 100644 --- a/crates/pulsing-actor/src/test_helper.rs +++ b/crates/pulsing-actor/src/test_helper.rs @@ -97,7 +97,11 @@ impl Default for TestEchoActor { #[async_trait] impl Actor for TestEchoActor { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> crate::error::Result { if msg.msg_type().ends_with("TestPing") { let ping: TestPing = msg.unpack()?; self.echo_count.fetch_add(1, Ordering::SeqCst); @@ -105,7 +109,9 @@ impl Actor for TestEchoActor { result: ping.value * 2, }); } - Err(anyhow::anyhow!("Unknown message type: {}", msg.msg_type())) + Err(crate::error::PulsingError::from( + crate::error::RuntimeError::Other(format!("Unknown message type: {}", msg.msg_type())), + )) } } @@ -131,7 +137,11 @@ impl Default for TestAccumulatorActor { #[async_trait] impl Actor for TestAccumulatorActor { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> crate::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("TestAccumulate") { let acc: TestAccumulate = msg.unpack()?; @@ -141,7 +151,9 @@ impl Actor for TestAccumulatorActor { if msg_type.ends_with("TestGetTotal") { return Message::pack(&TestTotalResponse { total: self.total }); } - Err(anyhow::anyhow!("Unknown message type: {}", msg_type)) + Err(crate::error::PulsingError::from( + crate::error::RuntimeError::Other(format!("Unknown message type: {}", msg_type)), + )) } } @@ -251,10 +263,10 @@ macro_rules! actor_test { macro_rules! 
actor_test_result { ($test_name:ident, $system:ident, $test_body:block) => { #[tokio::test] - async fn $test_name() -> anyhow::Result<()> { + async fn $test_name() -> $crate::error::Result<()> { let $system = $crate::test_helper::create_test_system().await; // Execute the test body - let test_result: anyhow::Result<()> = $test_body; + let test_result: $crate::error::Result<()> = $test_body; // Shutdown the system regardless of test result $system.shutdown().await?; test_result diff --git a/crates/pulsing-actor/src/transport/http2/client.rs b/crates/pulsing-actor/src/transport/http2/client.rs index 5940caa1a..27ca75e0e 100644 --- a/crates/pulsing-actor/src/transport/http2/client.rs +++ b/crates/pulsing-actor/src/transport/http2/client.rs @@ -6,7 +6,7 @@ use super::retry::{RetryConfig, RetryExecutor}; use super::stream::{BinaryFrameParser, StreamFrame, StreamHandle}; use super::{headers, MessageMode, RequestType}; use crate::actor::{Message, MessageStream}; -use crate::error::RuntimeError; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::tracing::{TraceContext, TRACEPARENT_HEADER}; use bytes::Bytes; use futures::{Stream, StreamExt, TryStreamExt}; @@ -20,12 +20,36 @@ use std::time::Duration; use tokio::sync::Mutex; use tokio_util::sync::CancellationToken; +/// Context for fault injection (testing / chaos). +#[derive(Clone, Debug)] +pub struct FaultInjectContext { + pub addr: SocketAddr, + pub path: String, + pub msg_type: String, + pub operation: FaultInjectOperation, +} + +/// Operation kind for fault injection. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FaultInjectOperation { + Ask, + Tell, + Stream, +} + +/// Fault injector for testing: optionally return an error before performing the request. +pub trait FaultInjector: Send + Sync { + /// If returns Some(error), the client will return this error without sending the request. + fn inject(&self, ctx: &FaultInjectContext) -> Option; +} + /// HTTP/2 client with connection pooling, retry, and timeout support. pub struct Http2Client { pool: Arc, config: Http2Config, retry_config: RetryConfig, cancel: CancellationToken, + fault_injector: Option>, } impl Http2Client { @@ -35,6 +59,7 @@ impl Http2Client { config, retry_config: RetryConfig::default(), cancel: CancellationToken::new(), + fault_injector: None, } } @@ -51,6 +76,7 @@ impl Http2Client { config: http2_config, retry_config, cancel: CancellationToken::new(), + fault_injector: None, } } @@ -59,6 +85,15 @@ impl Http2Client { self } + /// Set fault injector for testing / chaos engineering. When set, injector may return + /// an error before the request is sent. 
+ pub fn with_fault_injector(self, injector: Option>) -> Self { + Self { + fault_injector: injector, + ..self + } + } + pub fn pool(&self) -> &Arc { &self.pool } @@ -75,15 +110,36 @@ impl Http2Client { self.cancel.cancel(); } + fn check_fault_inject( + &self, + addr: SocketAddr, + path: &str, + msg_type: &str, + op: FaultInjectOperation, + ) -> Result<()> { + if let Some(ref injector) = self.fault_injector { + let ctx = FaultInjectContext { + addr, + path: path.to_string(), + msg_type: msg_type.to_string(), + operation: op, + }; + if let Some(err) = injector.inject(&ctx) { + return Err(err); + } + } + Ok(()) + } + pub async fn ask( &self, addr: SocketAddr, path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result> { + ) -> Result> { + self.check_fault_inject(addr, path, msg_type, FaultInjectOperation::Ask)?; let executor = RetryExecutor::new(self.retry_config.clone()); - executor .execute(true, || { self.ask_once(addr, path, msg_type, payload.clone()) @@ -97,7 +153,7 @@ impl Http2Client { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result> { + ) -> Result> { let response = self .send_request(addr, path, msg_type, payload, MessageMode::Ask) .await?; @@ -106,17 +162,18 @@ impl Http2Client { let body = tokio::time::timeout(self.config.request_timeout, response.collect()) .await - .map_err(|_| anyhow::anyhow!("Response body read timeout"))? - .map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))? + .map_err(|_| { + RuntimeError::request_timeout(self.config.request_timeout.as_millis() as u64) + })? + .map_err(|e| RuntimeError::Io(e.to_string()))? .to_bytes(); if !status.is_success() { let error_msg = String::from_utf8_lossy(&body); - return Err(anyhow::anyhow!( - "Request failed with status {}: {}", - status, - error_msg - )); + return Err(PulsingError::from(RuntimeError::invalid_response(format!( + "status {}: {}", + status, error_msg + )))); } Ok(body.to_vec()) @@ -128,9 +185,9 @@ impl Http2Client { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result<()> { + ) -> Result<()> { + self.check_fault_inject(addr, path, msg_type, FaultInjectOperation::Tell)?; let executor = RetryExecutor::new(self.retry_config.clone()); - executor .execute(false, || { self.tell_once(addr, path, msg_type, payload.clone()) @@ -144,7 +201,7 @@ impl Http2Client { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result<()> { + ) -> Result<()> { let response = self .send_request(addr, path, msg_type, payload, MessageMode::Tell) .await?; @@ -152,14 +209,16 @@ impl Http2Client { let status = response.status(); if !status.is_success() { - let body = response.collect().await?.to_bytes(); + let body = response + .collect() + .await + .map_err(|e| RuntimeError::Io(e.to_string()))? 
+ .to_bytes(); let error_msg = String::from_utf8_lossy(&body); - return Err(anyhow::anyhow!( - "Tell failed with status {} to {}: {}", - status, - addr, - error_msg - )); + return Err(PulsingError::from(RuntimeError::invalid_response(format!( + "tell status {} to {}: {}", + status, addr, error_msg + )))); } Ok(()) @@ -172,7 +231,8 @@ impl Http2Client { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result> { + ) -> Result> { + self.check_fault_inject(addr, path, msg_type, FaultInjectOperation::Stream)?; let response = self .send_request(addr, path, msg_type, payload, MessageMode::Stream) .await?; @@ -180,19 +240,21 @@ impl Http2Client { let status = response.status(); if !status.is_success() { - let body = response.collect().await?.to_bytes(); + let body = response + .collect() + .await + .map_err(|e| RuntimeError::Io(e.to_string()))? + .to_bytes(); let error_msg = String::from_utf8_lossy(&body); - return Err(anyhow::anyhow!( - "Stream request failed with status {}: {}", - status, - error_msg - )); + return Err(PulsingError::from(RuntimeError::invalid_response(format!( + "stream status {}: {}", + status, error_msg + )))); } let cancel = CancellationToken::new(); let cancel_clone = cancel.clone(); - // Apply stream timeout let stream_timeout = self.config.stream_timeout; let body_stream = response.into_body(); let frame_stream = Self::body_to_frame_stream(body_stream, cancel_clone, stream_timeout); @@ -209,21 +271,17 @@ impl Http2Client { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result { + ) -> Result { let stream_handle = self.ask_stream(addr, path, msg_type, payload).await?; - // Convert StreamFrame stream to Message stream using to_message() let msg_stream = stream_handle.filter_map(|result| async move { match result { - Ok(frame) => { - // Use the new to_message() method - match frame.to_message() { - Ok(Some(msg)) => Some(Ok(msg)), - Ok(None) => None, // End frame with no data - Err(e) => Some(Err(e)), - } - } - Err(e) => Some(Err(e)), + Ok(frame) => match frame.to_message() { + Ok(Some(msg)) => Some(Ok(msg)), + Ok(None) => None, + Err(e) => Some(Err(e)), + }, + Err(e) => Some(Err(PulsingError::from(RuntimeError::Other(e.to_string())))), } }); @@ -238,7 +296,7 @@ impl Http2Client { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result { + ) -> Result { let response = self .send_request(addr, path, msg_type, payload, MessageMode::Ask) .await?; @@ -256,17 +314,15 @@ impl Http2Client { addr: SocketAddr, path: &str, msg: Message, - ) -> anyhow::Result { + ) -> Result { match msg { Message::Single { msg_type, data } => { - // Single request - use existing method self.send_message(addr, path, &msg_type, data).await } Message::Stream { default_msg_type, stream, } => { - // Streaming request - send as binary frames let response = self .send_stream_request(addr, path, &default_msg_type, stream) .await?; @@ -285,7 +341,7 @@ impl Http2Client { path: &str, msg_type: &str, stream: MessageStream, - ) -> anyhow::Result> { + ) -> Result> { use hyper::client::conn::http2; use hyper_util::rt::{TokioExecutor, TokioIo}; use tokio::net::TcpStream; @@ -318,19 +374,22 @@ impl Http2Client { .map_err(|e| RuntimeError::connection_failed(addr.to_string(), e.to_string()))?; // Build HTTP/2 connection with streaming body type - with or without TLS - type StreamingBody = - StreamBody, Infallible>>>; + type StreamingBody = StreamBody< + tokio_stream::wrappers::ReceiverStream, Infallible>>, + >; #[cfg(feature = "tls")] if let Some(ref tls_config) = self.config.tls { - // TLS mode: 
wrap TCP stream with TLS let server_name = addr.ip().to_string(); - let tls_stream = tls_config.connect(tcp_stream, &server_name).await?; + let tls_stream = tls_config + .connect(tcp_stream, &server_name) + .await + .map_err(|e| RuntimeError::tls_error(e.to_string()))?; let io = TokioIo::new(tls_stream); let (mut sender, conn): (http2::SendRequest, _) = http2::handshake(TokioExecutor::new(), io) .await - .map_err(|e| anyhow::anyhow!("HTTP/2 TLS handshake failed: {}", e))?; + .map_err(|e| RuntimeError::tls_error(e.to_string()))?; // Spawn connection driver for TLS let cancel = self.cancel.clone(); @@ -348,7 +407,8 @@ impl Http2Client { }); // Complete the streaming request (TLS path) - let (tx, rx) = tokio::sync::mpsc::channel::, Infallible>>(32); + let (tx, rx) = + tokio::sync::mpsc::channel::, Infallible>>(32); let default_msg_type = msg_type.to_string(); tokio::spawn(async move { let mut stream = std::pin::pin!(stream); @@ -386,8 +446,10 @@ impl Http2Client { let send_future = sender.send_request(request); let response = tokio::time::timeout(self.config.stream_timeout, send_future) .await - .map_err(|_| anyhow::anyhow!("Streaming request timeout"))? - .map_err(|e| anyhow::anyhow!("Streaming request failed: {}", e))?; + .map_err(|_| { + RuntimeError::request_timeout(self.config.stream_timeout.as_millis() as u64) + })? + .map_err(|e| RuntimeError::protocol_error(e.to_string()))?; return Ok(response); } @@ -417,7 +479,8 @@ impl Http2Client { }); // Create a channel for streaming body - let (tx, rx) = tokio::sync::mpsc::channel::, Infallible>>(32); + let (tx, rx) = + tokio::sync::mpsc::channel::, Infallible>>(32); // Spawn task to convert Message stream to binary frames let default_msg_type = msg_type.to_string(); @@ -473,19 +536,21 @@ impl Http2Client { } /// Parse HTTP response into Message (handles both single and stream responses) - async fn parse_response(&self, response: hyper::Response) -> anyhow::Result { + async fn parse_response(&self, response: hyper::Response) -> Result { let status = response.status(); if !status.is_success() { - let body = response.collect().await?.to_bytes(); + let body = response + .collect() + .await + .map_err(|e| RuntimeError::Io(e.to_string()))? 
+ .to_bytes(); let error_msg = String::from_utf8_lossy(&body); - return Err(anyhow::anyhow!( - "Request failed: {} - {}", - status, - error_msg - )); + return Err(PulsingError::from(RuntimeError::invalid_response(format!( + "{} - {}", + status, error_msg + )))); } - // Check response type header let response_type = response .headers() .get(headers::RESPONSE_TYPE) @@ -493,7 +558,6 @@ impl Http2Client { .unwrap_or("single"); if response_type == "stream" { - // Stream response - parse binary frames let cancel = CancellationToken::new(); let cancel_clone = cancel.clone(); let stream_timeout = self.config.stream_timeout; @@ -502,15 +566,14 @@ impl Http2Client { Self::body_to_frame_stream(body_stream, cancel_clone, stream_timeout); let stream_handle = StreamHandle::new(frame_stream, cancel); - // Convert StreamFrame stream to Message stream let msg_stream = stream_handle.filter_map(|result| async move { match result { Ok(frame) => match frame.to_message() { Ok(Some(msg)) => Some(Ok(msg)), - Ok(None) => None, // End frame + Ok(None) => None, Err(e) => Some(Err(e)), }, - Err(e) => Some(Err(e)), + Err(e) => Some(Err(PulsingError::from(RuntimeError::Other(e.to_string())))), } }); @@ -519,18 +582,19 @@ impl Http2Client { stream: Box::pin(msg_stream), }) } else { - // Single response - read body directly let body = tokio::time::timeout(self.config.request_timeout, response.collect()) .await - .map_err(|_| anyhow::anyhow!("Response body read timeout"))? - .map_err(|e| anyhow::anyhow!("Failed to read body: {}", e))? + .map_err(|_| { + RuntimeError::request_timeout(self.config.request_timeout.as_millis() as u64) + })? + .map_err(|e| RuntimeError::Io(e.to_string()))? .to_bytes(); Ok(Message::single("", body.to_vec())) } } - /// Convert response body to stream of StreamFrames using binary format + /// Convert response body to stream of StreamFrames using binary format. fn body_to_frame_stream( body: Incoming, cancel: CancellationToken, @@ -548,7 +612,7 @@ impl Http2Client { .map(move |result| { let parser = parser.clone(); async move { - let frame = result.map_err(|e| anyhow::anyhow!("Body read error: {}", e))?; + let frame = result.map_err(|e| anyhow::anyhow!("Body read: {}", e))?; let data = frame .into_data() .map_err(|_| anyhow::anyhow!("Not data frame"))?; @@ -556,8 +620,11 @@ impl Http2Client { let mut parser = parser.lock().await; parser.push(&data); - // Parse all complete frames - let frames = parser.parse_all(); + let frames = parser + .parse_all() + .into_iter() + .map(|r| r.map_err(|e| anyhow::anyhow!("{}", e))) + .collect::>(); Ok::<_, anyhow::Error>(futures::stream::iter(frames)) } }) @@ -573,8 +640,7 @@ impl Http2Client { msg_type: &str, payload: Vec, mode: MessageMode, - ) -> anyhow::Result> { - // Create trace context for outgoing request + ) -> Result> { let trace_ctx = TraceContext::from_current() .map(|p| p.child()) .unwrap_or_default(); @@ -624,6 +690,7 @@ impl Clone for Http2Client { config: self.config.clone(), retry_config: self.retry_config.clone(), cancel: self.cancel.clone(), + fault_injector: self.fault_injector.clone(), } } } @@ -633,6 +700,7 @@ pub struct Http2ClientBuilder { http2_config: Http2Config, pool_config: Option, retry_config: Option, + fault_injector: Option>, } impl Http2ClientBuilder { @@ -642,9 +710,16 @@ impl Http2ClientBuilder { http2_config: Http2Config::default(), pool_config: None, retry_config: None, + fault_injector: None, } } + /// Set fault injector for testing / chaos engineering. 
+ pub fn fault_injector(mut self, injector: Option>) -> Self { + self.fault_injector = injector; + self + } + /// Set HTTP/2 configuration pub fn http2_config(mut self, config: Http2Config) -> Self { self.http2_config = config; self } @@ -689,11 +764,12 @@ impl Http2ClientBuilder { /// Build the client pub fn build(self) -> Http2Client { - Http2Client::with_configs( + let client = Http2Client::with_configs( self.http2_config, self.pool_config.unwrap_or_default(), self.retry_config.unwrap_or_default(), - ) + ); + client.with_fault_injector(self.fault_injector) } } @@ -745,4 +821,134 @@ mod tests { // Both should share the same pool assert!(Arc::ptr_eq(&client.pool, &cloned.pool)); } + + // --- Connection management --- + + #[test] + fn test_client_pool_and_stats() { + use std::sync::atomic::Ordering; + let client = Http2Client::new(Http2Config::default()); + let pool = client.pool(); + let stats = client.stats(); + assert_eq!(stats.connections_created.load(Ordering::Relaxed), 0); + assert_eq!(stats.connections_closed.load(Ordering::Relaxed), 0); + assert!(Arc::ptr_eq(client.pool(), pool)); + } + + #[test] + fn test_client_with_retry() { + let retry = RetryConfig::no_retry(); + let client = Http2Client::new(Http2Config::default()).with_retry(retry.clone()); + assert_eq!(client.retry_config.max_retries, 0); + } + + #[test] + fn test_client_with_configs() { + let http2_config = Http2Config::default(); + let pool_config = PoolConfig::default(); + let retry_config = RetryConfig::with_max_retries(2); + let client = Http2Client::with_configs(http2_config, pool_config, retry_config); + assert_eq!(client.retry_config.max_retries, 2); + } + + #[tokio::test] + async fn test_start_background_tasks_and_shutdown() { + let client = Http2Client::new(Http2Config::default()); + client.start_background_tasks(); + client.shutdown(); + // Shutdown again should be no-op + client.shutdown(); + } + + // --- Error recovery: requests to an unreachable address should return a connection error --- + + #[tokio::test] + async fn test_ask_connection_error() { + let client = + Http2Client::new(Http2Config::default().connect_timeout(Duration::from_millis(100))) + .with_retry(RetryConfig::no_retry()); + let addr: SocketAddr = "127.0.0.1:1".parse().unwrap(); + let result = client.ask(addr, "/actors/foo", "ping", vec![]).await; + let err = result.unwrap_err(); + let msg = err.to_string().to_lowercase(); + assert!( + msg.contains("connection") || msg.contains("refused") || msg.contains("reset"), + "expected connection-related error, got: {}", + msg + ); + } + + #[tokio::test] + async fn test_tell_connection_error() { + let client = + Http2Client::new(Http2Config::default().connect_timeout(Duration::from_millis(100))) + .with_retry(RetryConfig::no_retry()); + let addr: SocketAddr = "127.0.0.1:1".parse().unwrap(); + let result = client.tell(addr, "/actors/foo", "ping", vec![]).await; + let err = result.unwrap_err(); + let msg = err.to_string().to_lowercase(); + assert!( + msg.contains("connection") || msg.contains("refused") || msg.contains("reset"), + "expected connection-related error, got: {}", + msg + ); + } + + // --- Fault injection --- + + #[tokio::test] + async fn test_fault_injector_ask() { + struct InjectAllAsk; + impl FaultInjector for InjectAllAsk { + fn inject(&self, ctx: &FaultInjectContext) -> Option { + if ctx.operation == FaultInjectOperation::Ask { + Some(PulsingError::from(RuntimeError::connection_failed( + ctx.addr.to_string(), + "injected".to_string(), + ))) + } else { + None + } + } + } + let client = Http2Client::new(Http2Config::default()) + .with_fault_injector(Some(Arc::new(InjectAllAsk))); + let
addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let err = client.ask(addr, "/p", "t", vec![]).await.unwrap_err(); + assert!(err.to_string().to_lowercase().contains("injected")); + } + + #[tokio::test] + async fn test_fault_injector_tell() { + struct InjectTell; + impl FaultInjector for InjectTell { + fn inject(&self, ctx: &FaultInjectContext) -> Option { + if ctx.operation == FaultInjectOperation::Tell { + Some(PulsingError::from(RuntimeError::request_timeout(1))) + } else { + None + } + } + } + let client = Http2Client::new(Http2Config::default()) + .with_fault_injector(Some(Arc::new(InjectTell))); + let addr: SocketAddr = "127.0.0.1:9999".parse().unwrap(); + let err = client.tell(addr, "/p", "t", vec![]).await.unwrap_err(); + assert!(err.to_string().to_lowercase().contains("timeout")); + } + + #[tokio::test] + async fn test_fault_injector_none_no_effect() { + let client = + Http2Client::new(Http2Config::default().connect_timeout(Duration::from_millis(50))) + .with_retry(RetryConfig::no_retry()) + .with_fault_injector(None); + let addr: SocketAddr = "127.0.0.1:1".parse().unwrap(); + let result = client.ask(addr, "/p", "t", vec![]).await; + let err = result.unwrap_err(); + assert!( + err.to_string().to_lowercase().contains("connection") + || err.to_string().to_lowercase().contains("refused") + ); + } } diff --git a/crates/pulsing-actor/src/transport/http2/config.rs b/crates/pulsing-actor/src/transport/http2/config.rs index 698411c53..f1ab6f9a8 100644 --- a/crates/pulsing-actor/src/transport/http2/config.rs +++ b/crates/pulsing-actor/src/transport/http2/config.rs @@ -1,5 +1,6 @@ //! HTTP/2 transport configuration. +use crate::error::Result; use std::time::Duration; #[cfg(feature = "tls")] @@ -180,7 +181,7 @@ impl Http2Config { /// The passphrase is used to derive a shared CA certificate, enabling /// automatic mutual TLS authentication. #[cfg(feature = "tls")] - pub fn with_tls(mut self, passphrase: &str) -> anyhow::Result { + pub fn with_tls(mut self, passphrase: &str) -> Result { self.tls = Some(TlsConfig::from_passphrase(passphrase)?); Ok(self) } diff --git a/crates/pulsing-actor/src/transport/http2/mod.rs b/crates/pulsing-actor/src/transport/http2/mod.rs index 25292d3d8..ac0460fab 100644 --- a/crates/pulsing-actor/src/transport/http2/mod.rs +++ b/crates/pulsing-actor/src/transport/http2/mod.rs @@ -7,12 +7,14 @@ mod retry; mod server; mod stream; -use crate::error::RuntimeError; +use crate::error::{PulsingError, Result, RuntimeError}; #[cfg(feature = "tls")] mod tls; -pub use client::{Http2Client, Http2ClientBuilder}; +pub use client::{ + FaultInjectContext, FaultInjectOperation, FaultInjector, Http2Client, Http2ClientBuilder, +}; pub use config::Http2Config; pub use pool::{ConnectionPool, PoolConfig, PoolStats}; pub use retry::{RetryConfig, RetryExecutor, RetryableError}; @@ -42,7 +44,7 @@ impl Http2Transport { handler: Arc, config: Http2Config, cancel: CancellationToken, - ) -> anyhow::Result<(Arc, SocketAddr)> { + ) -> crate::error::Result<(Arc, SocketAddr)> { let client = Arc::new(Http2Client::new(config.clone())); client.start_background_tasks(); @@ -79,12 +81,7 @@ impl Http2Transport { } /// Send a request to an actor and wait for response. 
- pub async fn ask( - &self, - addr: SocketAddr, - actor_name: &str, - msg: Message, - ) -> anyhow::Result { + pub async fn ask(&self, addr: SocketAddr, actor_name: &str, msg: Message) -> Result { let path = format!("/actors/{}", actor_name); self.client.send_message_full(addr, &path, msg).await } @@ -95,17 +92,12 @@ impl Http2Transport { addr: SocketAddr, path: &ActorPath, msg: Message, - ) -> anyhow::Result { + ) -> Result { let url_path = format!("/named/{}", path.as_str()); self.client.send_message_full(addr, &url_path, msg).await } - pub async fn tell( - &self, - addr: SocketAddr, - actor_name: &str, - msg: Message, - ) -> anyhow::Result<()> { + pub async fn tell(&self, addr: SocketAddr, actor_name: &str, msg: Message) -> Result<()> { let path = format!("/actors/{}", actor_name); let Message::Single { msg_type, data } = msg else { return Err(RuntimeError::protocol_error("Streaming not supported for tell").into()); @@ -114,12 +106,7 @@ impl Http2Transport { self.client.tell(addr, &path, &msg_type, data).await } - pub async fn tell_named( - &self, - addr: SocketAddr, - path: &ActorPath, - msg: Message, - ) -> anyhow::Result<()> { + pub async fn tell_named(&self, addr: SocketAddr, path: &ActorPath, msg: Message) -> Result<()> { let url_path = format!("/named/{}", path.as_str()); let Message::Single { msg_type, data } = msg else { return Err(RuntimeError::protocol_error("Streaming not supported for tell").into()); @@ -129,11 +116,7 @@ impl Http2Transport { } /// Send a gossip message - pub async fn send_gossip( - &self, - addr: SocketAddr, - payload: Vec, - ) -> anyhow::Result>> { + pub async fn send_gossip(&self, addr: SocketAddr, payload: Vec) -> Result>> { let response = self .client .ask(addr, "/cluster/gossip", "gossip", payload) @@ -349,87 +332,67 @@ impl RemoteTransport for Http2RemoteTransport { _actor_id: &ActorId, msg_type: &str, payload: Vec, - ) -> anyhow::Result> { - // Check circuit breaker before making request + ) -> Result> { if !self.circuit_breaker.can_execute() { - return Err(RuntimeError::ConnectionFailed { + return Err(PulsingError::from(RuntimeError::ConnectionFailed { addr: self.remote_addr.to_string(), reason: "Circuit breaker is open".to_string(), - } - .into()); + })); } let result = self .client .ask(self.remote_addr, &self.path, msg_type, payload) - .await; + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string()))); - // Record outcome in circuit breaker self.circuit_breaker.record_outcome(result.is_ok()); result } - async fn send( - &self, - _actor_id: &ActorId, - msg_type: &str, - payload: Vec, - ) -> anyhow::Result<()> { - // Check circuit breaker before making request + async fn send(&self, _actor_id: &ActorId, msg_type: &str, payload: Vec) -> Result<()> { if !self.circuit_breaker.can_execute() { - return Err(RuntimeError::ConnectionFailed { + return Err(PulsingError::from(RuntimeError::ConnectionFailed { addr: self.remote_addr.to_string(), reason: "Circuit breaker is open".to_string(), - } - .into()); + })); } let result = self .client .tell(self.remote_addr, &self.path, msg_type, payload) - .await; + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string()))); - // Record outcome in circuit breaker self.circuit_breaker.record_outcome(result.is_ok()); result } /// Send a message and receive response (unified interface) - /// - /// This method is the primary way ActorRef communicates with remote actors. 
- /// It automatically handles both: - /// - Single and streaming requests (based on Message type) - /// - Single and streaming responses (based on server's response type header) - async fn send_message(&self, _actor_id: &ActorId, msg: Message) -> anyhow::Result { - // Check circuit breaker before making request + async fn send_message(&self, _actor_id: &ActorId, msg: Message) -> Result { if !self.circuit_breaker.can_execute() { - return Err(RuntimeError::ConnectionFailed { + return Err(PulsingError::from(RuntimeError::ConnectionFailed { addr: self.remote_addr.to_string(), reason: "Circuit breaker is open".to_string(), - } - .into()); + })); } - // Use unified send_message_full that handles both single and streaming let result = self .client .send_message_full(self.remote_addr, &self.path, msg) - .await; + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string()))); - // Record outcome in circuit breaker self.circuit_breaker.record_outcome(result.is_ok()); result } /// Send a one-way message (unified interface) - /// - /// Note: Streaming requests are NOT supported for fire-and-forget messages - /// because there's no way to know when the stream is fully consumed. - async fn send_oneway(&self, actor_id: &ActorId, msg: Message) -> anyhow::Result<()> { + async fn send_oneway(&self, actor_id: &ActorId, msg: Message) -> Result<()> { let Message::Single { msg_type, data } = msg else { - return Err(anyhow::anyhow!( - "Streaming not supported for fire-and-forget (use ask pattern instead)" - )); + return Err(PulsingError::from(RuntimeError::Other( + "Streaming not supported for fire-and-forget (use ask pattern instead)".into(), + ))); }; self.send(actor_id, &msg_type, data).await } diff --git a/crates/pulsing-actor/src/transport/http2/pool.rs b/crates/pulsing-actor/src/transport/http2/pool.rs index 50aa509f4..689e93d85 100644 --- a/crates/pulsing-actor/src/transport/http2/pool.rs +++ b/crates/pulsing-actor/src/transport/http2/pool.rs @@ -1,6 +1,7 @@ //! Connection pool management for HTTP/2 transport. 
use super::config::Http2Config; +use crate::error::{PulsingError, Result, RuntimeError}; use bytes::Bytes; use http_body_util::Full; use hyper::client::conn::http2; @@ -235,7 +236,7 @@ impl ConnectionPool { } /// Get or create a connection to the given address - pub async fn get_connection(&self, addr: SocketAddr) -> anyhow::Result { + pub async fn get_connection(&self, addr: SocketAddr) -> Result { // Try to get an existing healthy connection first if let Some(conn) = self.try_get_existing(addr).await { self.stats @@ -245,7 +246,12 @@ impl ConnectionPool { } // Create a new connection - self.create_new_connection(addr).await + self.create_new_connection(addr).await.map_err(|e| { + PulsingError::from(RuntimeError::connection_failed( + addr.to_string(), + e.to_string(), + )) + }) } /// Try to get an existing healthy connection @@ -362,7 +368,10 @@ impl ConnectionPool { if let Some(ref tls_config) = self.http2_config.tls { // TLS mode: wrap TCP stream with TLS let server_name = addr.ip().to_string(); - let tls_stream = tls_config.connect(stream, &server_name).await?; + let tls_stream = tls_config + .connect(stream, &server_name) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; let io = TokioIo::new(tls_stream); let (sender, conn) = http2::handshake(TokioExecutor::new(), io) .await diff --git a/crates/pulsing-actor/src/transport/http2/retry.rs b/crates/pulsing-actor/src/transport/http2/retry.rs index 19b8bf20b..77c1bce6a 100644 --- a/crates/pulsing-actor/src/transport/http2/retry.rs +++ b/crates/pulsing-actor/src/transport/http2/retry.rs @@ -1,5 +1,6 @@ //! Retry and timeout strategies for HTTP/2 transport. +use crate::error::{PulsingError, Result}; use std::time::Duration; /// Retry configuration. @@ -115,13 +116,19 @@ pub enum RetryableError { } impl RetryableError { - pub fn classify(error: &anyhow::Error) -> Self { + /// Classify PulsingError for retry decision (preferred). + pub fn classify_pulsing(error: &PulsingError) -> Self { let msg = error.to_string().to_lowercase(); + Self::classify_msg(&msg) + } + + /// Classify by error message string (used for anyhow or PulsingError). + fn classify_msg(msg: &str) -> Self { + let msg = msg.to_lowercase(); if msg.contains("backing off") { return Self::Unknown; } - if msg.contains("connection") || msg.contains("connect") || msg.contains("refused") @@ -129,15 +136,12 @@ impl RetryableError { { return Self::Connection; } - if msg.contains("timeout") || msg.contains("timed out") { return Self::Timeout; } - if msg.contains("503") || msg.contains("service unavailable") { return Self::ServerOverloaded; } - if msg.contains("500") || msg.contains("502") || msg.contains("504") @@ -147,7 +151,6 @@ impl RetryableError { { return Self::ServerError; } - if msg.contains("400") || msg.contains("401") || msg.contains("403") @@ -159,10 +162,14 @@ impl RetryableError { { return Self::ClientError; } - Self::Unknown } + /// Classify anyhow::Error (for compatibility with code still using anyhow). + pub fn classify(error: &anyhow::Error) -> Self { + Self::classify_msg(&error.to_string()) + } + pub fn is_retryable(&self, idempotent_only: bool, is_idempotent: bool) -> bool { match self { // Connection errors are always retryable @@ -191,16 +198,15 @@ impl RetryExecutor { Self { config } } - /// Execute a function with retry logic - pub async fn execute(&self, is_idempotent: bool, mut f: F) -> anyhow::Result + /// Execute a function with retry logic (unified error type). 
+ pub async fn execute(&self, is_idempotent: bool, mut f: F) -> Result where F: FnMut() -> Fut, - Fut: std::future::Future>, + Fut: std::future::Future>, { - let mut last_error = None; + let mut last_error: Option = None; for attempt in 0..=self.config.max_retries { - // Wait before retry (except for first attempt) if attempt > 0 { let delay = self.config.delay_for_attempt(attempt); tracing::debug!( @@ -214,7 +220,7 @@ impl RetryExecutor { match f().await { Ok(result) => return Ok(result), Err(e) => { - let error_type = RetryableError::classify(&e); + let error_type = RetryableError::classify_pulsing(&e); let should_retry = attempt < self.config.max_retries && error_type.is_retryable(self.config.idempotent_only, is_idempotent); @@ -234,7 +240,11 @@ impl RetryExecutor { } } - Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Max retries exceeded"))) + Err(last_error.unwrap_or_else(|| { + PulsingError::from(crate::error::RuntimeError::Other( + "Max retries exceeded".to_string(), + )) + })) } } @@ -328,7 +338,7 @@ mod tests { async fn test_retry_executor_success() { let executor = RetryExecutor::new(RetryConfig::with_max_retries(3)); let result = executor - .execute(true, || async { Ok::<_, anyhow::Error>(42) }) + .execute(true, || async { Ok::<_, PulsingError>(42) }) .await; assert_eq!(result.unwrap(), 42); } @@ -347,7 +357,12 @@ mod tests { async move { let n = count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); if n < 2 { - Err(anyhow::anyhow!("Connection refused")) + Err(PulsingError::from( + crate::error::RuntimeError::connection_failed( + "127.0.0.1:1", + "Connection refused", + ), + )) } else { Ok(42) } @@ -358,4 +373,20 @@ mod tests { assert_eq!(result.unwrap(), 42); assert_eq!(counter.load(std::sync::atomic::Ordering::SeqCst), 3); } + + #[test] + fn test_error_classification_pulsing() { + use crate::error::RuntimeError; + let conn_err = + PulsingError::from(RuntimeError::connection_failed("127.0.0.1:1", "refused")); + assert_eq!( + RetryableError::classify_pulsing(&conn_err), + RetryableError::Connection + ); + let timeout_err = PulsingError::from(RuntimeError::request_timeout(5000)); + assert_eq!( + RetryableError::classify_pulsing(&timeout_err), + RetryableError::Timeout + ); + } } diff --git a/crates/pulsing-actor/src/transport/http2/server.rs b/crates/pulsing-actor/src/transport/http2/server.rs index 7f853645c..49a83ac40 100644 --- a/crates/pulsing-actor/src/transport/http2/server.rs +++ b/crates/pulsing-actor/src/transport/http2/server.rs @@ -4,6 +4,7 @@ use super::config::Http2Config; use super::stream::{BinaryFrameParser, StreamFrame}; use super::{headers, MessageMode, RequestType}; use crate::actor::Message; +use crate::error::{PulsingError, Result, RuntimeError}; use crate::tracing::{TraceContext, TRACEPARENT_HEADER}; use bytes::Bytes; use futures::StreamExt; @@ -25,14 +26,14 @@ use tokio_util::sync::CancellationToken; #[async_trait::async_trait] pub trait Http2ServerHandler: Send + Sync + 'static { /// Unified message handler. - async fn handle_message_full(&self, path: &str, msg: Message) -> anyhow::Result { + async fn handle_message_full(&self, path: &str, msg: Message) -> Result { match msg { Message::Single { msg_type, data } => { self.handle_message_simple(path, &msg_type, data).await } - Message::Stream { .. } => Err(anyhow::anyhow!( - "Streaming requests not supported by this handler" - )), + Message::Stream { .. 
} => Err(PulsingError::from(RuntimeError::Other( + "Streaming requests not supported by this handler".into(), + ))), } } @@ -42,21 +43,22 @@ pub trait Http2ServerHandler: Send + Sync + 'static { path: &str, msg_type: &str, payload: Vec, - ) -> anyhow::Result { + ) -> Result { let _ = (path, msg_type, payload); - Err(anyhow::anyhow!("Not implemented")) + Err(PulsingError::from(RuntimeError::Other( + "Not implemented".into(), + ))) } /// Handle tell (fire-and-forget) message. - async fn handle_tell(&self, path: &str, msg_type: &str, payload: Vec) - -> anyhow::Result<()>; + async fn handle_tell(&self, path: &str, msg_type: &str, payload: Vec) -> Result<()>; /// Handle gossip message. async fn handle_gossip( &self, payload: Vec, peer_addr: SocketAddr, - ) -> anyhow::Result>>; + ) -> Result>>; /// Get health status. async fn health_check(&self) -> serde_json::Value { @@ -88,7 +90,7 @@ pub trait Http2ServerHandler: Send + Sync + 'static { _path: &str, _method: &str, _body: Vec, - ) -> anyhow::Result>> { + ) -> Result>> { Ok(None) } } @@ -105,9 +107,13 @@ impl Http2Server { handler: Arc, config: Http2Config, cancel: CancellationToken, - ) -> anyhow::Result { - let listener = TcpListener::bind(bind_addr).await?; - let local_addr = listener.local_addr()?; + ) -> Result { + let listener = TcpListener::bind(bind_addr) + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string())))?; + let local_addr = listener + .local_addr() + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string())))?; tracing::info!(addr = %local_addr, "Starting HTTP/2 server"); @@ -183,7 +189,10 @@ impl Http2Server { #[cfg(feature = "tls")] if let Some(ref tls_config) = config.tls { // TLS mode: accept TLS handshake first - let tls_stream = tls_config.accept(stream).await?; + let tls_stream = tls_config + .accept(stream) + .await + .map_err(|e| anyhow::anyhow!("{}", e))?; let io = TokioIo::new(tls_stream); // TLS connections always use HTTP/2 (no HTTP/1.1 fallback) @@ -251,7 +260,7 @@ impl Http2Server { req: Request, handler: Arc, peer_addr: SocketAddr, - ) -> Result, Infallible> { + ) -> std::result::Result, Infallible> { let path = req.uri().path().to_string(); let method = req.method().clone(); @@ -438,36 +447,31 @@ impl Http2Server { /// Parse a streaming request body (binary frames) into Message::Stream fn parse_streaming_request(body: Incoming, default_msg_type: &str) -> Message { - let (tx, rx) = mpsc::channel::>(32); + let (tx, rx) = mpsc::channel::>(32); let default_msg_type = default_msg_type.to_string(); - // Spawn task to parse binary frames tokio::spawn(async move { let mut parser = BinaryFrameParser::new(); let mut body_stream = http_body_util::BodyStream::new(body); - while let Some(result) = body_stream.next().await { match result { Ok(frame) => { if let Ok(data) = frame.into_data() { parser.push(&data); - // Parse all complete frames for frame_result in parser.parse_all() { match frame_result { Ok(frame) => { - // Skip end frames if frame.end && frame.get_data().is_empty() { continue; } - // Convert frame to message match frame.to_message() { Ok(Some(msg)) => { if tx.send(Ok(msg)).await.is_err() { return; } } - Ok(None) => {} // Skip empty frames + Ok(None) => {} Err(e) => { let _ = tx.send(Err(e)).await; return; @@ -475,7 +479,11 @@ impl Http2Server { } } Err(e) => { - let _ = tx.send(Err(e)).await; + let _ = tx + .send(Err(PulsingError::from(RuntimeError::Other( + e.to_string(), + )))) + .await; return; } } @@ -484,7 +492,7 @@ impl Http2Server { } Err(e) => { let _ = tx - 
.send(Err(anyhow::anyhow!("Body read error: {}", e))) + .send(Err(PulsingError::from(RuntimeError::Other(e.to_string())))) .await; return; } @@ -500,7 +508,7 @@ impl Http2Server { handler: &Arc, payload: Vec, peer_addr: SocketAddr, - ) -> Result, Infallible> { + ) -> std::result::Result, Infallible> { match handler.handle_gossip(payload, peer_addr).await { Ok(Some(response)) => Ok(octet_response(StatusCode::OK, response)), Ok(None) => Ok(empty_response(StatusCode::OK)), @@ -516,7 +524,7 @@ impl Http2Server { handler: &Arc, path: &str, msg: Message, - ) -> Result, Infallible> { + ) -> std::result::Result, Infallible> { match handler.handle_message_full(path, msg).await { Ok(Message::Single { data, .. }) => { // Single response - return directly with response type header @@ -531,7 +539,7 @@ impl Http2Server { stream, }) => { // Stream response - convert Message stream to binary frames - let (tx, rx) = mpsc::channel::, Infallible>>(32); + let (tx, rx) = mpsc::channel::, Infallible>>(32); tokio::spawn(async move { let mut stream = std::pin::pin!(stream); @@ -574,7 +582,7 @@ impl Http2Server { path: &str, msg_type: &str, payload: Vec, - ) -> Result, Infallible> { + ) -> std::result::Result, Infallible> { match handler.handle_tell(path, msg_type, payload).await { Ok(()) => Ok(empty_response(StatusCode::ACCEPTED)), Err(e) => Ok(error_response( @@ -696,17 +704,11 @@ mod tests { _path: &str, _msg_type: &str, payload: Vec, - ) -> anyhow::Result { - // Echo the payload as single message + ) -> Result { Ok(Message::single("", payload)) } - async fn handle_tell( - &self, - _path: &str, - _msg_type: &str, - _payload: Vec, - ) -> anyhow::Result<()> { + async fn handle_tell(&self, _path: &str, _msg_type: &str, _payload: Vec) -> Result<()> { Ok(()) } @@ -714,7 +716,7 @@ mod tests { &self, _payload: Vec, _peer_addr: SocketAddr, - ) -> anyhow::Result>> { + ) -> Result>> { Ok(None) } diff --git a/crates/pulsing-actor/src/transport/http2/stream.rs b/crates/pulsing-actor/src/transport/http2/stream.rs index ebd6ed328..31cf4738a 100644 --- a/crates/pulsing-actor/src/transport/http2/stream.rs +++ b/crates/pulsing-actor/src/transport/http2/stream.rs @@ -1,6 +1,7 @@ //! Streaming support for HTTP/2 transport. 
use crate::actor::Message; +use crate::error::{PulsingError, Result, RuntimeError}; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures::Stream; use std::pin::Pin; @@ -75,9 +76,11 @@ impl StreamFrame { } } - pub fn to_message(&self) -> anyhow::Result> { + pub fn to_message(&self) -> crate::error::Result> { if let Some(ref error) = self.error { - return Err(anyhow::anyhow!("{}", error)); + return Err(crate::error::PulsingError::from( + crate::error::RuntimeError::Other(error.clone()), + )); } if self.end && self.data.is_empty() { return Ok(None); @@ -122,17 +125,19 @@ impl StreamFrame { buf.freeze() } - pub fn from_binary(mut buf: &[u8]) -> anyhow::Result { + pub fn from_binary(mut buf: &[u8]) -> Result { if buf.remaining() < 4 { - return Err(anyhow::anyhow!("Buffer too short for length")); + return Err(PulsingError::from(RuntimeError::Other( + "Buffer too short for length".into(), + ))); } let total_len = buf.get_u32() as usize; if buf.remaining() < total_len { - return Err(anyhow::anyhow!( + return Err(PulsingError::from(RuntimeError::Other(format!( "Incomplete frame: expected {} bytes", total_len - )); + )))); } let flags = buf.get_u8(); @@ -141,18 +146,28 @@ impl StreamFrame { let msg_type_len = buf.get_u16() as usize; if buf.remaining() < msg_type_len { - return Err(anyhow::anyhow!("Invalid msg_type length")); + return Err(PulsingError::from(RuntimeError::Other( + "Invalid msg_type length".into(), + ))); } - let msg_type = String::from_utf8(buf[..msg_type_len].to_vec()) - .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in msg_type: {}", e))?; + let msg_type = String::from_utf8(buf[..msg_type_len].to_vec()).map_err(|e| { + PulsingError::from(RuntimeError::Other(format!( + "Invalid UTF-8 in msg_type: {}", + e + ))) + })?; buf.advance(msg_type_len); if buf.remaining() < 4 { - return Err(anyhow::anyhow!("Missing data length")); + return Err(PulsingError::from(RuntimeError::Other( + "Missing data length".into(), + ))); } let data_len = buf.get_u32() as usize; if buf.remaining() < data_len { - return Err(anyhow::anyhow!("Invalid data length")); + return Err(PulsingError::from(RuntimeError::Other( + "Invalid data length".into(), + ))); } let data = buf[..data_len].to_vec(); buf.advance(data_len); @@ -161,10 +176,12 @@ impl StreamFrame { let error = if has_error && buf.remaining() >= 2 { let error_len = buf.get_u16() as usize; if buf.remaining() >= error_len { - Some( - String::from_utf8(buf[..error_len].to_vec()) - .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in error: {}", e))?, - ) + Some(String::from_utf8(buf[..error_len].to_vec()).map_err(|e| { + PulsingError::from(RuntimeError::Other(format!( + "Invalid UTF-8 in error: {}", + e + ))) + })?) } else { None } @@ -256,7 +273,7 @@ impl BinaryFrameParser { } /// Try to parse a complete frame from the buffer - pub fn try_parse(&mut self) -> Option> { + pub fn try_parse(&mut self) -> Option> { if self.buffer.len() < 4 { return None; } @@ -280,7 +297,7 @@ impl BinaryFrameParser { } /// Parse all available frames - pub fn parse_all(&mut self) -> Vec> { + pub fn parse_all(&mut self) -> Vec> { let mut frames = Vec::new(); while let Some(result) = self.try_parse() { frames.push(result); diff --git a/crates/pulsing-actor/src/transport/http2/tls.rs b/crates/pulsing-actor/src/transport/http2/tls.rs index 1ebcaf7ce..c8a6b4d9e 100644 --- a/crates/pulsing-actor/src/transport/http2/tls.rs +++ b/crates/pulsing-actor/src/transport/http2/tls.rs @@ -1,5 +1,6 @@ //! TLS support for HTTP/2 transport (passphrase-derived certificates). 
+use crate::error::{PulsingError, Result, RuntimeError}; use rcgen::{ BasicConstraints, Certificate, CertificateParams, DnType, ExtendedKeyUsagePurpose, IsCa, KeyPair, KeyUsagePurpose, SerialNumber, PKCS_ED25519, @@ -39,12 +40,14 @@ pub struct TlsConfig { impl TlsConfig { /// Create TLS configuration from a passphrase. - pub fn from_passphrase(passphrase: &str) -> anyhow::Result { + pub fn from_passphrase(passphrase: &str) -> Result { ensure_crypto_provider(); - let (ca_cert, ca_key_pair) = derive_ca_from_passphrase(passphrase)?; + let (ca_cert, ca_key_pair) = derive_ca_from_passphrase(passphrase) + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; - let (node_cert, node_key_pair) = generate_node_cert(&ca_cert, &ca_key_pair)?; + let (node_cert, node_key_pair) = generate_node_cert(&ca_cert, &ca_key_pair) + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; let ca_cert_der = CertificateDer::from(ca_cert.der().to_vec()); let node_cert_der = CertificateDer::from(node_cert.der().to_vec()); @@ -52,21 +55,23 @@ impl TlsConfig { PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(node_key_pair.serialize_der())); let mut root_store = RootCertStore::empty(); - root_store.add(ca_cert_der.clone())?; + root_store + .add(ca_cert_der.clone()) + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; let client_verifier = WebPkiClientVerifier::builder(Arc::new(root_store.clone())) .build() - .map_err(|e| anyhow::anyhow!("Failed to build client verifier: {}", e))?; + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; let server_config = ServerConfig::builder() .with_client_cert_verifier(client_verifier) .with_single_cert(vec![node_cert_der.clone()], node_key_der.clone_key()) - .map_err(|e| anyhow::anyhow!("Failed to build server config: {}", e))?; + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; let client_config = ClientConfig::builder() .with_root_certificates(root_store) .with_client_auth_cert(vec![node_cert_der], node_key_der) - .map_err(|e| anyhow::anyhow!("Failed to build client config: {}", e))?; + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; let hash = Sha256::digest(passphrase.as_bytes()); let hash_slice = hash.as_slice(); @@ -89,24 +94,24 @@ impl TlsConfig { &self, stream: tokio::net::TcpStream, _server_name: &str, - ) -> anyhow::Result> { + ) -> Result> { let server_name = ServerName::try_from("pulsing.internal".to_string()) - .map_err(|e| anyhow::anyhow!("Invalid server name: {}", e))?; + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string())))?; self.connector .connect(server_name, stream) .await - .map_err(|e| anyhow::anyhow!("TLS connect failed: {}", e)) + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string()))) } pub async fn accept( &self, stream: tokio::net::TcpStream, - ) -> anyhow::Result> { + ) -> Result> { self.acceptor .accept(stream) .await - .map_err(|e| anyhow::anyhow!("TLS accept failed: {}", e)) + .map_err(|e| PulsingError::from(RuntimeError::tls_error(e.to_string()))) } } diff --git a/crates/pulsing-actor/tests/cluster/mod.rs b/crates/pulsing-actor/tests/cluster/mod.rs deleted file mode 100644 index 8a2627ba0..000000000 --- a/crates/pulsing-actor/tests/cluster/mod.rs +++ /dev/null @@ -1,10 +0,0 @@ -//! Cluster module tests -//! -//! This module contains tests for cluster-related functionality: -//! - Member management (member_tests) -//! - Naming backend (naming_tests) -//! 
- Gossip protocol (gossip_tests) -//! - SWIM protocol (swim_tests) - -mod member_tests; -mod naming_tests; diff --git a/crates/pulsing-actor/tests/common/fixtures.rs b/crates/pulsing-actor/tests/common/fixtures.rs new file mode 100644 index 000000000..e04cb1d30 --- /dev/null +++ b/crates/pulsing-actor/tests/common/fixtures.rs @@ -0,0 +1,229 @@ +//! Shared test fixtures for transport tests +#![allow(dead_code)] + +use pulsing_actor::actor::Message; +use pulsing_actor::transport::{Http2Config, Http2Server, Http2ServerHandler}; +use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +/// Counter for tracking handler calls +#[derive(Default)] +pub struct TestCounters { + pub ask_count: AtomicUsize, + pub tell_count: AtomicUsize, +} + +/// Test handler that echoes messages and counts calls +pub struct TestHandler { + pub counters: Arc, +} + +impl TestHandler { + pub fn new() -> Self { + Self { + counters: Arc::new(TestCounters::default()), + } + } + + pub fn with_counters(counters: Arc) -> Self { + Self { counters } + } +} + +#[async_trait::async_trait] +impl Http2ServerHandler for TestHandler { + async fn handle_message_simple( + &self, + path: &str, + msg_type: &str, + payload: Vec, + ) -> pulsing_actor::error::Result { + self.counters.ask_count.fetch_add(1, Ordering::SeqCst); + + // Echo the payload with path and msg_type prepended + let mut response = format!("{}:{}:", path, msg_type).into_bytes(); + response.extend(payload); + Ok(Message::single("", response)) + } + + async fn handle_tell( + &self, + _path: &str, + _msg_type: &str, + _payload: Vec, + ) -> pulsing_actor::error::Result<()> { + self.counters.tell_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn handle_gossip( + &self, + _payload: Vec, + _peer_addr: std::net::SocketAddr, + ) -> pulsing_actor::error::Result>> { + Ok(None) + } + + async fn health_check(&self) -> serde_json::Value { + serde_json::json!({ + "status": "healthy", + "ask_count": self.counters.ask_count.load(Ordering::SeqCst), + "tell_count": self.counters.tell_count.load(Ordering::SeqCst), + }) + } + + async fn cluster_members(&self) -> serde_json::Value { + serde_json::json!([ + { + "node_id": "12345", + "addr": "127.0.0.1:8001", + "status": "Alive" + }, + { + "node_id": "67890", + "addr": "127.0.0.1:8002", + "status": "Alive" + } + ]) + } + + async fn actors_list(&self, include_internal: bool) -> serde_json::Value { + let mut actors = vec![ + serde_json::json!({ + "name": "counter-1", + "type": "user", + "actor_id": "12345:1", + "class": "Counter", + "module": "__main__" + }), + serde_json::json!({ + "name": "calculator", + "type": "user", + "actor_id": "12345:2", + "class": "Calculator", + "module": "__main__" + }), + ]; + + if include_internal { + actors.push(serde_json::json!({ + "name": "_python_actor_service", + "type": "system", + "actor_id": "12345:0" + })); + } + + serde_json::json!(actors) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +/// Handler that supports streaming requests +pub struct StreamingHandler { + pub counters: Arc, +} + +impl StreamingHandler { + pub fn with_counters(counters: Arc) -> Self { + Self { counters } + } +} + +#[async_trait::async_trait] +impl Http2ServerHandler for StreamingHandler { + async fn handle_message_full( + &self, + path: &str, + msg: Message, + ) -> pulsing_actor::error::Result { + use futures::StreamExt; + self.counters.ask_count.fetch_add(1, Ordering::SeqCst); + + match msg { + Message::Single { msg_type, data } => { + // 
Echo single message + let mut response = format!("{}:{}:", path, msg_type).into_bytes(); + response.extend(data); + Ok(Message::single("echo", response)) + } + Message::Stream { stream, .. } => { + // Collect all stream data and echo back as single response + let mut collected = Vec::new(); + let mut stream = std::pin::pin!(stream); + while let Some(result) = stream.next().await { + match result { + Ok(Message::Single { data, .. }) => { + collected.extend(data); + } + Ok(Message::Stream { .. }) => { + return Err(pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other( + "Nested streams not supported".into(), + ), + )); + } + Err(e) => return Err(e), + } + } + let response = format!("{}:collected:{} bytes", path, collected.len()).into_bytes(); + Ok(Message::single("stream_echo", response)) + } + } + } + + async fn handle_message_simple( + &self, + path: &str, + msg_type: &str, + payload: Vec, + ) -> pulsing_actor::error::Result { + self.counters.ask_count.fetch_add(1, Ordering::SeqCst); + let mut response = format!("{}:{}:", path, msg_type).into_bytes(); + response.extend(payload); + Ok(Message::single("", response)) + } + + async fn handle_tell( + &self, + _path: &str, + _msg_type: &str, + _payload: Vec, + ) -> pulsing_actor::error::Result<()> { + self.counters.tell_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn handle_gossip( + &self, + _payload: Vec, + _peer_addr: std::net::SocketAddr, + ) -> pulsing_actor::error::Result>> { + Ok(None) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +/// Helper to create a test server with handler and return (server, counters, cancel_token) +pub async fn create_test_server() -> (Http2Server, Arc, CancellationToken) { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + (server, counters, cancel) +} diff --git a/crates/pulsing-actor/tests/common/mod.rs b/crates/pulsing-actor/tests/common/mod.rs new file mode 100644 index 000000000..d066349cc --- /dev/null +++ b/crates/pulsing-actor/tests/common/mod.rs @@ -0,0 +1 @@ +pub mod fixtures; diff --git a/crates/pulsing-actor/tests/http2_transport_tests.rs b/crates/pulsing-actor/tests/http2_transport_tests.rs deleted file mode 100644 index 864ac521f..000000000 --- a/crates/pulsing-actor/tests/http2_transport_tests.rs +++ /dev/null @@ -1,1684 +0,0 @@ -//! HTTP/2 Transport layer tests -//! -//! Tests for HTTP/2 (h2c) transport including: -//! - Server creation and startup -//! - Client connection and requests -//! - Ask (request-response) pattern -//! - Tell (fire-and-forget) pattern -//! 
- Http2RemoteTransport integration - -use pulsing_actor::actor::{ActorId, Message}; -use pulsing_actor::transport::{ - Http2Client, Http2Config, Http2RemoteTransport, Http2Server, Http2ServerHandler, -}; -use std::sync::atomic::{AtomicUsize, Ordering}; -use std::sync::Arc; -use tokio_util::sync::CancellationToken; - -// ============================================================================ -// Test Handler Implementation -// ============================================================================ - -/// Counter for tracking handler calls -#[derive(Default)] -struct TestCounters { - ask_count: AtomicUsize, - tell_count: AtomicUsize, -} - -/// Test handler that echoes messages and counts calls -struct TestHandler { - counters: Arc, -} - -impl TestHandler { - fn new() -> Self { - Self { - counters: Arc::new(TestCounters::default()), - } - } - - fn with_counters(counters: Arc) -> Self { - Self { counters } - } -} - -#[async_trait::async_trait] -impl Http2ServerHandler for TestHandler { - async fn handle_message_simple( - &self, - path: &str, - msg_type: &str, - payload: Vec, - ) -> anyhow::Result { - self.counters.ask_count.fetch_add(1, Ordering::SeqCst); - - // Echo the payload with path and msg_type prepended - let mut response = format!("{}:{}:", path, msg_type).into_bytes(); - response.extend(payload); - Ok(Message::single("", response)) - } - - async fn handle_tell( - &self, - _path: &str, - _msg_type: &str, - _payload: Vec, - ) -> anyhow::Result<()> { - self.counters.tell_count.fetch_add(1, Ordering::SeqCst); - Ok(()) - } - - async fn handle_gossip( - &self, - _payload: Vec, - _peer_addr: std::net::SocketAddr, - ) -> anyhow::Result>> { - Ok(None) - } - - async fn health_check(&self) -> serde_json::Value { - serde_json::json!({ - "status": "healthy", - "ask_count": self.counters.ask_count.load(Ordering::SeqCst), - "tell_count": self.counters.tell_count.load(Ordering::SeqCst), - }) - } - - async fn cluster_members(&self) -> serde_json::Value { - serde_json::json!([ - { - "node_id": "12345", - "addr": "127.0.0.1:8001", - "status": "Alive" - }, - { - "node_id": "67890", - "addr": "127.0.0.1:8002", - "status": "Alive" - } - ]) - } - - async fn actors_list(&self, include_internal: bool) -> serde_json::Value { - let mut actors = vec![ - serde_json::json!({ - "name": "counter-1", - "type": "user", - "actor_id": "12345:1", - "class": "Counter", - "module": "__main__" - }), - serde_json::json!({ - "name": "calculator", - "type": "user", - "actor_id": "12345:2", - "class": "Calculator", - "module": "__main__" - }), - ]; - - if include_internal { - actors.push(serde_json::json!({ - "name": "_python_actor_service", - "type": "system", - "actor_id": "12345:0" - })); - } - - serde_json::json!(actors) - } - - fn as_any(&self) -> &dyn std::any::Any { - self - } -} - -// ============================================================================ -// Server Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_server_creation() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - // Server should be bound to a valid port - assert_ne!(server.local_addr().port(), 0); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_server_multiple_instances() { - let cancel = CancellationToken::new(); - - // Create multiple 
servers - let mut servers = Vec::new(); - for _ in 0..3 { - let handler = Arc::new(TestHandler::new()); - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - servers.push(server); - } - - // All servers should have different ports - let ports: Vec = servers.iter().map(|s| s.local_addr().port()).collect(); - let unique_ports: std::collections::HashSet<_> = ports.iter().collect(); - assert_eq!(ports.len(), unique_ports.len()); - - // Cleanup - cancel.cancel(); -} - -// ============================================================================ -// Client Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_client_creation() { - let client = Http2Client::new(Http2Config::default()); - // Client should be clonable - let _cloned = client.clone(); -} - -#[tokio::test] -async fn test_http2_ask_request() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(TestHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client and send request - let client = Http2Client::new(Http2Config::default()); - let response = client - .ask(addr, "/actors/test", "TestMsg", b"hello".to_vec()) - .await - .unwrap(); - - // Verify response - let response_str = String::from_utf8_lossy(&response); - assert!(response_str.contains("/actors/test")); - assert!(response_str.contains("TestMsg")); - assert!(response_str.contains("hello")); - - // Verify counter - assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_tell_request() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(TestHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client and send tell - let client = Http2Client::new(Http2Config::default()); - client - .tell(addr, "/actors/test", "FireAndForget", b"data".to_vec()) - .await - .unwrap(); - - // Give the server time to process - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - - // Verify counter - assert_eq!(counters.tell_count.load(Ordering::SeqCst), 1); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_multiple_requests() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(TestHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Http2Client::new(Http2Config::default()); - - // Send multiple requests - for i in 0..10 { - let response = client - .ask( - addr, - "/actors/test", - "Msg", - format!("request-{}", i).into_bytes(), - ) - .await - .unwrap(); - - let response_str = String::from_utf8_lossy(&response); - assert!(response_str.contains(&format!("request-{}", i))); - } - - // Verify 
counter - assert_eq!(counters.ask_count.load(Ordering::SeqCst), 10); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_concurrent_requests() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(TestHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Arc::new(Http2Client::new(Http2Config::default())); - - // Send concurrent requests - let mut handles = Vec::new(); - for i in 0..20 { - let client = client.clone(); - let handle = tokio::spawn(async move { - client - .ask( - addr, - "/actors/test", - "Concurrent", - format!("req-{}", i).into_bytes(), - ) - .await - }); - handles.push(handle); - } - - // Wait for all requests - let results: Vec<_> = futures::future::join_all(handles).await; - - // All should succeed - for result in results { - assert!(result.unwrap().is_ok()); - } - - // Verify counter - assert_eq!(counters.ask_count.load(Ordering::SeqCst), 20); - - // Cleanup - cancel.cancel(); -} - -// ============================================================================ -// Http2RemoteTransport Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_remote_transport_ask() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create transport - let client = Arc::new(Http2Client::new(Http2Config::default())); - let transport = Http2RemoteTransport::new(client, addr, "test-actor".to_string()); - - // Use the RemoteTransport trait - use pulsing_actor::actor::RemoteTransport; - - let actor_id = ActorId::generate(); - let response = transport - .request(&actor_id, "TestType", b"payload".to_vec()) - .await - .unwrap(); - - // Verify response contains expected data - let response_str = String::from_utf8_lossy(&response); - assert!(response_str.contains("test-actor")); - assert!(response_str.contains("TestType")); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_remote_transport_tell() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(TestHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create transport - let client = Arc::new(Http2Client::new(Http2Config::default())); - let transport = Http2RemoteTransport::new(client, addr, "fire-actor".to_string()); - - // Use the RemoteTransport trait - use pulsing_actor::actor::RemoteTransport; - - let actor_id = ActorId::generate(); - transport - .send(&actor_id, "FireMsg", b"data".to_vec()) - .await - .unwrap(); - - // Give the server time to process - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; - - // Verify counter - assert_eq!(counters.tell_count.load(Ordering::SeqCst), 1); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_remote_transport_named_path() { - let handler = 
Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create transport with named path - let client = Arc::new(Http2Client::new(Http2Config::default())); - use pulsing_actor::actor::ActorPath; - let path = ActorPath::new("services/llm/worker").unwrap(); - let transport = Http2RemoteTransport::new_named(client, addr, path); - - // Use the RemoteTransport trait - use pulsing_actor::actor::RemoteTransport; - - let actor_id = ActorId::generate(); - let response = transport - .request(&actor_id, "Inference", b"prompt".to_vec()) - .await - .unwrap(); - - // Verify response contains the path - let response_str = String::from_utf8_lossy(&response); - assert!(response_str.contains("services/llm/worker")); - - // Cleanup - cancel.cancel(); -} - -// ============================================================================ -// Configuration Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_custom_config() { - let config = Http2Config::new() - .max_concurrent_streams(50) - .initial_window_size(1024 * 1024) - .connect_timeout(std::time::Duration::from_secs(10)) - .request_timeout(std::time::Duration::from_secs(60)); - - assert_eq!(config.max_concurrent_streams, 50); - assert_eq!(config.initial_window_size, 1024 * 1024); - assert_eq!(config.connect_timeout, std::time::Duration::from_secs(10)); - assert_eq!(config.request_timeout, std::time::Duration::from_secs(60)); -} - -#[tokio::test] -async fn test_http2_server_with_custom_config() { - let mut config = Http2Config::new().max_concurrent_streams(100); - config.max_frame_size = 32 * 1024; - - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - config, - cancel.clone(), - ) - .await - .unwrap(); - - assert_ne!(server.local_addr().port(), 0); - - // Cleanup - cancel.cancel(); -} - -// ============================================================================ -// Error Handling Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_connection_refused() { - let client = Http2Client::new(Http2Config::default()); - - // Try to connect to a port that's not listening - let result = client - .ask( - "127.0.0.1:1".parse().unwrap(), - "/actors/test", - "Test", - vec![], - ) - .await; - - // Should fail - assert!(result.is_err()); -} - -#[tokio::test] -async fn test_http2_server_shutdown() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client and verify it works - let client = Http2Client::new(Http2Config::default()); - let result = client.ask(addr, "/actors/test", "Test", vec![]).await; - assert!(result.is_ok()); - - // Shutdown server - server.shutdown(); - - // Give server time to shutdown - tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; - - // New connections should fail (eventually) - // Note: existing connections may still work briefly -} - -// 
============================================================================ -// Stream Frame Tests -// ============================================================================ - -#[test] -fn test_stream_frame_data() { - use pulsing_actor::transport::StreamFrame; - - let frame = StreamFrame::data("token", b"hello"); - assert_eq!(frame.msg_type, "token"); - assert!(!frame.end); - assert!(frame.error.is_none()); - assert_eq!(frame.get_data(), b"hello"); -} - -#[test] -fn test_stream_frame_end() { - use pulsing_actor::transport::StreamFrame; - - let frame = StreamFrame::end(); - assert!(frame.end); - assert!(frame.error.is_none()); -} - -#[test] -fn test_stream_frame_error() { - use pulsing_actor::transport::StreamFrame; - - let frame = StreamFrame::error("something went wrong"); - assert!(frame.end); - assert!(frame.is_error()); - assert_eq!(frame.error.as_ref().unwrap(), "something went wrong"); -} - -#[test] -fn test_stream_frame_binary_roundtrip() { - use pulsing_actor::transport::StreamFrame; - - let original = StreamFrame::data("response", b"world"); - let bytes = original.to_binary(); - let parsed = StreamFrame::from_binary(&bytes).unwrap(); - - assert_eq!(parsed.msg_type, "response"); - assert_eq!(parsed.get_data(), b"world"); -} - -// ============================================================================ -// Message Mode Tests -// ============================================================================ - -#[test] -fn test_message_mode_conversion() { - use pulsing_actor::transport::MessageMode; - - assert_eq!(MessageMode::Ask.as_str(), "ask"); - assert_eq!(MessageMode::Tell.as_str(), "tell"); - assert_eq!(MessageMode::Stream.as_str(), "stream"); - - assert_eq!(MessageMode::parse("ask"), Some(MessageMode::Ask)); - assert_eq!(MessageMode::parse("TELL"), Some(MessageMode::Tell)); - assert_eq!(MessageMode::parse("Stream"), Some(MessageMode::Stream)); - assert_eq!(MessageMode::parse("invalid"), None); -} - -// ============================================================================ -// Unified Send Message Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_unified_send_message() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(TestHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client and send using unified send_message - let client = Http2Client::new(Http2Config::default()); - let response = client - .send_message(addr, "/actors/test", "TestMsg", b"test-payload".to_vec()) - .await - .unwrap(); - - // Should receive single response - assert!(response.is_single()); - - // Verify response content - let Message::Single { data, .. 
} = response else { - panic!("Expected single message"); - }; - let data_str = String::from_utf8_lossy(&data); - assert!(data_str.contains("/actors/test")); - assert!(data_str.contains("TestMsg")); - - // Verify ask counter - assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); - - // Cleanup - cancel.cancel(); -} - -// ============================================================================ -// Performance / Benchmark Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_throughput_benchmark() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::high_throughput(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Arc::new(Http2Client::new(Http2Config::high_throughput())); - - let request_count = 1000; - let start = std::time::Instant::now(); - - // Send requests sequentially - for i in 0..request_count { - let _ = client - .ask( - addr, - "/actors/bench", - "Bench", - format!("req-{}", i).into_bytes(), - ) - .await - .unwrap(); - } - - let elapsed = start.elapsed(); - let rps = request_count as f64 / elapsed.as_secs_f64(); - - println!( - "HTTP/2 Sequential Throughput: {} requests in {:?} ({:.0} req/s)", - request_count, elapsed, rps - ); - - // Should handle at least 100 req/s (very conservative) - assert!(rps > 100.0, "Throughput too low: {} req/s", rps); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_concurrent_throughput_benchmark() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::high_throughput(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Arc::new(Http2Client::new(Http2Config::high_throughput())); - - let request_count = 1000; - let concurrency = 50; - let start = std::time::Instant::now(); - - // Send concurrent requests - let mut handles = Vec::new(); - for i in 0..request_count { - let client = client.clone(); - let handle = tokio::spawn(async move { - client - .ask( - addr, - "/actors/bench", - "Bench", - format!("req-{}", i).into_bytes(), - ) - .await - }); - handles.push(handle); - - // Limit concurrency - if handles.len() >= concurrency { - let results: Vec<_> = futures::future::join_all(handles.drain(..)).await; - for r in results { - r.unwrap().unwrap(); - } - } - } - - // Wait for remaining - let results: Vec<_> = futures::future::join_all(handles).await; - for r in results { - r.unwrap().unwrap(); - } - - let elapsed = start.elapsed(); - let rps = request_count as f64 / elapsed.as_secs_f64(); - - println!( - "HTTP/2 Concurrent Throughput ({} concurrency): {} requests in {:?} ({:.0} req/s)", - concurrency, request_count, elapsed, rps - ); - - // Should be faster than sequential (relaxed for CI environments) - assert!(rps > 100.0, "Concurrent throughput too low: {} req/s", rps); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_latency_benchmark() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::low_latency(), - cancel.clone(), - ) - .await - .unwrap(); - - 
let addr = server.local_addr(); - - // Create client - let client = Http2Client::new(Http2Config::low_latency()); - - let request_count = 100; - let mut latencies = Vec::with_capacity(request_count); - - // Measure individual request latencies - for i in 0..request_count { - let start = std::time::Instant::now(); - let _ = client - .ask( - addr, - "/actors/latency", - "Ping", - format!("req-{}", i).into_bytes(), - ) - .await - .unwrap(); - latencies.push(start.elapsed()); - } - - // Calculate statistics - latencies.sort(); - let min = latencies.first().unwrap(); - let max = latencies.last().unwrap(); - let median = latencies[request_count / 2]; - let p99 = latencies[(request_count * 99) / 100]; - let avg: std::time::Duration = - latencies.iter().sum::() / request_count as u32; - - println!( - "HTTP/2 Latency: min={:?}, avg={:?}, median={:?}, p99={:?}, max={:?}", - min, avg, median, p99, max - ); - - // P99 should be under 500ms for localhost (relaxed for CI environments) - assert!( - p99 < std::time::Duration::from_millis(500), - "P99 latency too high: {:?}", - p99 - ); - - // Cleanup - cancel.cancel(); -} - -// ============================================================================ -// Connection Pool Tests -// ============================================================================ - -#[tokio::test] -async fn test_http2_connection_pool_reuse() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Http2Client::new(Http2Config::default()); - - // Send multiple requests to trigger connection reuse - for i in 0..10 { - let _ = client - .ask( - addr, - "/actors/pool-test", - "Msg", - format!("req-{}", i).into_bytes(), - ) - .await - .unwrap(); - } - - // Check pool stats - let stats = client.stats(); - let created = stats.connections_created.load(Ordering::Relaxed); - let reused = stats.connections_reused.load(Ordering::Relaxed); - - println!("Connection Pool: created={}, reused={}", created, reused); - - // Should have created very few connections but reused many times - // HTTP/2 multiplexing means we should only need 1 connection - assert!(created <= 2, "Too many connections created: {}", created); - assert!(reused >= 8, "Not enough connection reuse: {}", reused); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_retry_on_connection_error() { - use pulsing_actor::transport::{Http2ClientBuilder, RetryConfig}; - use std::time::Duration; - - // Create client with retry - let client = Http2ClientBuilder::new() - .retry_config(RetryConfig::with_max_retries(2).initial_delay(Duration::from_millis(10))) - .connect_timeout(Duration::from_millis(100)) - .build(); - - // Try to connect to a port that doesn't exist - let result: Result, _> = client - .ask( - "127.0.0.1:1".parse().unwrap(), - "/actors/test", - "Msg", - vec![], - ) - .await; - - // Should fail after retries - assert!(result.is_err()); - let err = result.unwrap_err().to_string(); - // Error message should indicate connection failure - assert!( - err.contains("Connection") - || err.contains("connect") - || err.contains("timeout") - || err.contains("backing off"), - "Unexpected error: {}", - err - ); -} - -// ============================================================================ -// Streaming Request Tests -// 
============================================================================ - -/// Handler that supports streaming requests -struct StreamingHandler { - counters: Arc, -} - -impl StreamingHandler { - fn with_counters(counters: Arc) -> Self { - Self { counters } - } -} - -#[async_trait::async_trait] -impl Http2ServerHandler for StreamingHandler { - async fn handle_message_full(&self, path: &str, msg: Message) -> anyhow::Result { - use futures::StreamExt; - self.counters.ask_count.fetch_add(1, Ordering::SeqCst); - - match msg { - Message::Single { msg_type, data } => { - // Echo single message - let mut response = format!("{}:{}:", path, msg_type).into_bytes(); - response.extend(data); - Ok(Message::single("echo", response)) - } - Message::Stream { stream, .. } => { - // Collect all stream data and echo back as single response - let mut collected = Vec::new(); - let mut stream = std::pin::pin!(stream); - while let Some(result) = stream.next().await { - match result { - Ok(Message::Single { data, .. }) => { - collected.extend(data); - } - Ok(Message::Stream { .. }) => { - return Err(anyhow::anyhow!("Nested streams not supported")); - } - Err(e) => return Err(e), - } - } - let response = format!("{}:collected:{} bytes", path, collected.len()).into_bytes(); - Ok(Message::single("stream_echo", response)) - } - } - } - - async fn handle_message_simple( - &self, - path: &str, - msg_type: &str, - payload: Vec, - ) -> anyhow::Result { - self.counters.ask_count.fetch_add(1, Ordering::SeqCst); - let mut response = format!("{}:{}:", path, msg_type).into_bytes(); - response.extend(payload); - Ok(Message::single("", response)) - } - - async fn handle_tell( - &self, - _path: &str, - _msg_type: &str, - _payload: Vec, - ) -> anyhow::Result<()> { - self.counters.tell_count.fetch_add(1, Ordering::SeqCst); - Ok(()) - } - - async fn handle_gossip( - &self, - _payload: Vec, - _peer_addr: std::net::SocketAddr, - ) -> anyhow::Result>> { - Ok(None) - } - - fn as_any(&self) -> &dyn std::any::Any { - self - } -} - -#[tokio::test] -async fn test_http2_streaming_request() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(StreamingHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Http2Client::new(Http2Config::default()); - - // Create a streaming request - let (tx, rx) = tokio::sync::mpsc::channel::>(10); - - // Send some messages through the stream - tokio::spawn(async move { - for i in 0..5 { - let msg = Message::single("chunk", format!("data-{}", i).into_bytes()); - if tx.send(Ok(msg)).await.is_err() { - break; - } - } - }); - - // Create stream message from channel - let stream_msg = Message::from_channel("StreamRequest", rx); - - // Send streaming request - let response = client - .send_message_full(addr, "/actors/stream_test", stream_msg) - .await - .unwrap(); - - // Verify response - assert!(response.is_single()); - let Message::Single { data, .. 
} = response else { - panic!("Expected single message"); - }; - let response_str = String::from_utf8_lossy(&data); - assert!(response_str.contains("/actors/stream_test")); - assert!(response_str.contains("collected:")); - - // Verify handler was called - assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); - - // Cleanup - cancel.cancel(); -} - -#[tokio::test] -async fn test_http2_single_request_with_full_api() { - let counters = Arc::new(TestCounters::default()); - let handler = Arc::new(StreamingHandler::with_counters(counters.clone())); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Http2Client::new(Http2Config::default()); - - // Send single message using send_message_full - let msg = Message::single("TestType", b"test-payload".to_vec()); - let response = client - .send_message_full(addr, "/actors/single_test", msg) - .await - .unwrap(); - - // Verify response - assert!(response.is_single()); - let Message::Single { data, .. } = response else { - panic!("Expected single message"); - }; - let response_str = String::from_utf8_lossy(&data); - assert!(response_str.contains("/actors/single_test")); - assert!(response_str.contains("TestType")); - - // Verify handler was called - assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); - - // Cleanup - cancel.cancel(); -} - -#[test] -fn test_request_type_conversion() { - use pulsing_actor::transport::RequestType; - - assert_eq!(RequestType::Single.as_str(), "single"); - assert_eq!(RequestType::Stream.as_str(), "stream"); - - assert_eq!(RequestType::parse("single"), Some(RequestType::Single)); - assert_eq!(RequestType::parse("STREAM"), Some(RequestType::Stream)); - assert_eq!(RequestType::parse("invalid"), None); -} - -// ============================================================================ -// Tracing Tests -// ============================================================================ - -mod tracing_tests { - use super::*; - use pulsing_actor::tracing::opentelemetry::trace::TraceContextExt; - use pulsing_actor::tracing::{TraceContext, TRACEPARENT_HEADER}; - use std::sync::Mutex; - - /// Handler that captures trace context from incoming requests - struct TracingTestHandler { - counters: Arc, - captured_traces: Arc>>>, - } - - impl TracingTestHandler { - fn new() -> Self { - Self { - counters: Arc::new(TestCounters::default()), - captured_traces: Arc::new(Mutex::new(Vec::new())), - } - } - - #[allow(dead_code)] - fn captured_traces(&self) -> Vec> { - self.captured_traces.lock().unwrap().clone() - } - } - - #[async_trait::async_trait] - impl Http2ServerHandler for TracingTestHandler { - async fn handle_message_simple( - &self, - path: &str, - msg_type: &str, - payload: Vec, - ) -> anyhow::Result { - self.counters.ask_count.fetch_add(1, Ordering::SeqCst); - - // Try to get current trace context - let trace_ctx = TraceContext::from_current(); - self.captured_traces - .lock() - .unwrap() - .push(trace_ctx.map(|t| t.to_traceparent())); - - // Echo response with trace info - let response = format!("{}:{}:{}", path, msg_type, payload.len()); - Ok(Message::single("traced", response.into_bytes())) - } - - async fn handle_tell( - &self, - _path: &str, - _msg_type: &str, - _payload: Vec, - ) -> anyhow::Result<()> { - self.counters.tell_count.fetch_add(1, Ordering::SeqCst); - Ok(()) - } - - async fn 
handle_gossip( - &self, - _payload: Vec, - _peer_addr: std::net::SocketAddr, - ) -> anyhow::Result>> { - Ok(None) - } - - fn as_any(&self) -> &dyn std::any::Any { - self - } - } - - #[test] - fn test_trace_context_creation() { - let ctx = TraceContext::default(); - assert_eq!(ctx.trace_id.len(), 32); - assert_eq!(ctx.span_id.len(), 16); - assert_eq!(ctx.trace_flags, 0x01); // Sampled - } - - #[test] - fn test_trace_context_to_traceparent() { - let ctx = TraceContext { - trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(), - span_id: "b7ad6b7169203331".to_string(), - trace_flags: 0x01, - trace_state: None, - }; - - let header = ctx.to_traceparent(); - assert_eq!( - header, - "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" - ); - } - - #[test] - fn test_trace_context_from_traceparent() { - let header = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"; - let ctx = TraceContext::from_traceparent(header).unwrap(); - - assert_eq!(ctx.trace_id, "0af7651916cd43dd8448eb211c80319c"); - assert_eq!(ctx.span_id, "b7ad6b7169203331"); - assert_eq!(ctx.trace_flags, 0x01); - } - - #[test] - fn test_trace_context_roundtrip() { - let original = TraceContext::default(); - let header = original.to_traceparent(); - let parsed = TraceContext::from_traceparent(&header).unwrap(); - - assert_eq!(original.trace_id, parsed.trace_id); - assert_eq!(original.span_id, parsed.span_id); - assert_eq!(original.trace_flags, parsed.trace_flags); - } - - #[test] - fn test_trace_context_child() { - let parent = TraceContext::default(); - let child = parent.child(); - - // Same trace ID - assert_eq!(parent.trace_id, child.trace_id); - // Different span ID - assert_ne!(parent.span_id, child.span_id); - // Same flags - assert_eq!(parent.trace_flags, child.trace_flags); - } - - #[test] - fn test_invalid_traceparent_formats() { - // Too few parts - assert!(TraceContext::from_traceparent("invalid").is_none()); - assert!(TraceContext::from_traceparent("00-abc-def").is_none()); - - // Wrong lengths - assert!(TraceContext::from_traceparent("00-short-id-01").is_none()); - assert!( - TraceContext::from_traceparent("00-0af7651916cd43dd8448eb211c80319c-short-01") - .is_none() - ); - - // Invalid hex in flags - assert!(TraceContext::from_traceparent( - "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-zz" - ) - .is_none()); - } - - #[test] - fn test_traceparent_header_constant() { - assert_eq!(TRACEPARENT_HEADER, "traceparent"); - } - - #[test] - fn test_trace_context_not_sampled() { - let ctx = TraceContext { - trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(), - span_id: "b7ad6b7169203331".to_string(), - trace_flags: 0x00, // Not sampled - trace_state: None, - }; - - let header = ctx.to_traceparent(); - assert!(header.ends_with("-00")); - - let parsed = TraceContext::from_traceparent(&header).unwrap(); - assert_eq!(parsed.trace_flags, 0x00); - } - - #[test] - fn test_new_child_span_id_uniqueness() { - let ids: Vec = (0..100) - .map(|_| TraceContext::new_child_span_id()) - .collect(); - - // All IDs should be unique - let mut unique_ids = ids.clone(); - unique_ids.sort(); - unique_ids.dedup(); - assert_eq!(ids.len(), unique_ids.len()); - - // All IDs should be 16 hex chars - for id in &ids { - assert_eq!(id.len(), 16); - assert!(id.chars().all(|c| c.is_ascii_hexdigit())); - } - } - - #[tokio::test] - async fn test_http2_request_with_tracing() { - let handler = Arc::new(TracingTestHandler::new()); - let cancel = CancellationToken::new(); - - // Start server - let server = Http2Server::new( - 
"127.0.0.1:0".parse().unwrap(), - handler.clone(), - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client - let client = Http2Client::new(Http2Config::default()); - - // Send request (client should inject traceparent) - let response = client - .ask(addr, "/actors/traced", "test", b"hello".to_vec()) - .await - .unwrap(); - - // Verify response - assert!(!response.is_empty()); - - // Cleanup - cancel.cancel(); - } - - #[tokio::test] - async fn test_multiple_requests_different_traces() { - let handler = Arc::new(TracingTestHandler::new()); - let cancel = CancellationToken::new(); - - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler.clone(), - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - let client = Http2Client::new(Http2Config::default()); - - // Send multiple requests - for i in 0..3 { - let _ = client - .ask( - addr, - "/actors/test", - "type", - format!("msg-{}", i).into_bytes(), - ) - .await - .unwrap(); - } - - // Each request should have its own trace - assert_eq!(handler.counters.ask_count.load(Ordering::SeqCst), 3); - - cancel.cancel(); - } - - #[test] - fn test_trace_context_otel_conversion() { - let ctx = TraceContext { - trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(), - span_id: "b7ad6b7169203331".to_string(), - trace_flags: 0x01, - trace_state: None, - }; - - // Convert to OpenTelemetry context - let otel_ctx = ctx.to_otel_context(); - - // The context should be valid (not panic) - assert!(!otel_ctx - .span() - .span_context() - .trace_id() - .to_string() - .is_empty()); - } -} - -// ============================================================================ -// REST API Endpoint Tests -// ============================================================================ - -mod rest_api_tests { - use super::*; - - /// Test /cluster/members endpoint - #[tokio::test] - async fn test_cluster_members_endpoint() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - let _server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler.clone(), - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - // Test cluster members endpoint - let members = handler.cluster_members().await; - assert!(members.is_array()); - let members_array = members.as_array().unwrap(); - assert_eq!(members_array.len(), 2); - - // Verify member structure - let member = &members_array[0]; - assert!(member.get("node_id").is_some()); - assert!(member.get("addr").is_some()); - assert!(member.get("status").is_some()); - - cancel.cancel(); - } - - /// Test /actors endpoint - #[tokio::test] - async fn test_actors_list_endpoint() { - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - let _server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler.clone(), - Http2Config::default(), - cancel.clone(), - ) - .await - .unwrap(); - - // Test actors list (user only) - let actors = handler.actors_list(false).await; - assert!(actors.is_array()); - let actors_array = actors.as_array().unwrap(); - assert_eq!(actors_array.len(), 2); // Only user actors - - // Verify actor structure - let actor = &actors_array[0]; - assert!(actor.get("name").is_some()); - assert!(actor.get("type").is_some()); - assert_eq!(actor.get("type").unwrap(), "user"); - - // Test actors list (include internal) - let all_actors = handler.actors_list(true).await; - let all_actors_array = 
all_actors.as_array().unwrap(); - assert_eq!(all_actors_array.len(), 3); // Includes system actor - - cancel.cancel(); - } - - /// Test actor metadata in actors list - #[tokio::test] - async fn test_actors_list_metadata() { - let handler = Arc::new(TestHandler::new()); - - // Test actors list has metadata - let actors = handler.actors_list(false).await; - let actors_array = actors.as_array().unwrap(); - let actor = &actors_array[0]; - - // Verify metadata fields - assert!(actor.get("actor_id").is_some()); - assert!(actor.get("class").is_some()); - assert!(actor.get("module").is_some()); - - // Verify values - assert_eq!(actor.get("class").unwrap(), "Counter"); - assert_eq!(actor.get("module").unwrap(), "__main__"); - } - - /// Test health check endpoint returns expected structure - #[tokio::test] - async fn test_health_check_endpoint() { - let handler = Arc::new(TestHandler::new()); - - let health = handler.health_check().await; - - assert!(health.get("status").is_some()); - assert_eq!(health.get("status").unwrap(), "healthy"); - assert!(health.get("ask_count").is_some()); - assert!(health.get("tell_count").is_some()); - } -} - -// ============================================================================ -// TLS Tests (requires `tls` feature) -// ============================================================================ - -#[cfg(feature = "tls")] -mod tls_tests { - use super::*; - use pulsing_actor::transport::http2::TlsConfig; - - /// Test TLS configuration creation from passphrase - #[test] - fn test_tls_config_from_passphrase() { - let config = TlsConfig::from_passphrase("test-cluster-password"); - assert!(config.is_ok(), "TLS config creation failed: {:?}", config); - } - - /// Test that same passphrase produces deterministic CA - #[test] - fn test_tls_deterministic_ca() { - let config1 = TlsConfig::from_passphrase("deterministic-test-password").unwrap(); - let config2 = TlsConfig::from_passphrase("deterministic-test-password").unwrap(); - - // Both should have the same passphrase hash - assert_eq!(config1.passphrase_hash(), config2.passphrase_hash()); - } - - /// Test that different passphrase produces different CA - #[test] - fn test_tls_different_passphrase() { - let config1 = TlsConfig::from_passphrase("password-one").unwrap(); - let config2 = TlsConfig::from_passphrase("password-two").unwrap(); - - // Different passwords should have different hashes - assert_ne!(config1.passphrase_hash(), config2.passphrase_hash()); - } - - /// Test TLS-enabled HTTP/2 server and client communication - #[tokio::test] - async fn test_tls_server_client_communication() { - let tls_config = TlsConfig::from_passphrase("test-cluster-tls").unwrap(); - - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Create HTTP/2 config with TLS - let http2_config = Http2Config::default().tls_config(tls_config.clone()); - - // Start TLS server - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler.clone(), - http2_config.clone(), - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create TLS client with same passphrase - let client = Http2Client::new(http2_config); - - // Send request over TLS - let response = client - .ask(addr, "/actors/test", "test-msg", b"hello tls".to_vec()) - .await; - - // Should succeed - assert!(response.is_ok(), "TLS request failed: {:?}", response); - - let body = response.unwrap(); - let response_str = String::from_utf8_lossy(&body); - assert!( - response_str.contains("hello tls"), - 
"Response should contain original message" - ); - - cancel.cancel(); - } - - /// Test that different passphrase fails TLS handshake - #[tokio::test] - async fn test_tls_different_passphrase_fails() { - let server_tls = TlsConfig::from_passphrase("server-password").unwrap(); - let client_tls = TlsConfig::from_passphrase("wrong-password").unwrap(); - - let handler = Arc::new(TestHandler::new()); - let cancel = CancellationToken::new(); - - // Start TLS server - let server_config = Http2Config::default().tls_config(server_tls); - let server = Http2Server::new( - "127.0.0.1:0".parse().unwrap(), - handler, - server_config, - cancel.clone(), - ) - .await - .unwrap(); - - let addr = server.local_addr(); - - // Create client with WRONG passphrase - let client_config = Http2Config::default().tls_config(client_tls); - let client = Http2Client::new(client_config); - - // Request should fail due to TLS handshake failure - let response = client - .ask(addr, "/actors/test", "test", b"test".to_vec()) - .await; - - // Should fail - assert!( - response.is_err(), - "Request with different passphrase should fail" - ); - - cancel.cancel(); - } -} diff --git a/crates/pulsing-actor/tests/integration/main.rs b/crates/pulsing-actor/tests/integration/main.rs new file mode 100644 index 000000000..34e9786ba --- /dev/null +++ b/crates/pulsing-actor/tests/integration/main.rs @@ -0,0 +1,5 @@ +#[path = "../common/mod.rs"] +mod common; + +mod multi_node_tests; +mod single_node_tests; diff --git a/crates/pulsing-actor/tests/multi_node_tests.rs b/crates/pulsing-actor/tests/integration/multi_node_tests.rs similarity index 97% rename from crates/pulsing-actor/tests/multi_node_tests.rs rename to crates/pulsing-actor/tests/integration/multi_node_tests.rs index d38a44f49..44696cd88 100644 --- a/crates/pulsing-actor/tests/multi_node_tests.rs +++ b/crates/pulsing-actor/tests/integration/multi_node_tests.rs @@ -1,6 +1,7 @@ //! 
Multi-node cluster integration tests use pulsing_actor::actor::{ActorAddress, ActorPath}; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use pulsing_actor::ActorSystemOpsExt; use std::sync::atomic::{AtomicI32, Ordering}; @@ -40,14 +41,20 @@ struct Echo; #[async_trait] impl Actor for Echo { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result<Message> { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result<Message> { if msg.msg_type().ends_with("Ping") { let ping: Ping = msg.unpack()?; return Message::pack(&Pong { result: ping.value * 2, }); } - Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } @@ -57,7 +64,11 @@ struct Counter { #[async_trait] impl Actor for Counter { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result<Message> { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result<Message> { if msg.msg_type().ends_with("Increment") { let new_count = self.count.fetch_add(1, Ordering::SeqCst) + 1; return Message::pack(&CountResponse { count: new_count }); @@ -66,7 +77,9 @@ impl Actor for Counter { let count = self.count.load(Ordering::SeqCst); return Message::pack(&CountResponse { count }); } - Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } @@ -151,7 +164,7 @@ mod two_node_tests { // Multi-Node Cluster Tests // ============================================================================ -mod multi_node_tests { +mod multi_node { use super::*; #[tokio::test] diff --git a/crates/pulsing-actor/tests/integration_tests.rs b/crates/pulsing-actor/tests/integration/single_node_tests.rs similarity index 93% rename from crates/pulsing-actor/tests/integration_tests.rs rename to crates/pulsing-actor/tests/integration/single_node_tests.rs index 4056a78b5..3c852d49d 100644 --- a/crates/pulsing-actor/tests/integration_tests.rs +++ b/crates/pulsing-actor/tests/integration/single_node_tests.rs @@ -1,6 +1,7 @@ //!
Integration tests for the complete actor system use pulsing_actor::actor::{ActorAddress, ActorPath}; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use pulsing_actor::ActorSystemOpsExt; use std::sync::atomic::{AtomicI32, AtomicUsize, Ordering}; @@ -44,7 +45,11 @@ struct EchoActor { #[async_trait] impl Actor for EchoActor { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { if msg.msg_type().ends_with("Ping") { let ping: Ping = msg.unpack()?; self.echo_count.fetch_add(1, Ordering::SeqCst); @@ -52,7 +57,9 @@ impl Actor for EchoActor { result: ping.value * 2, }); } - Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } @@ -62,7 +69,11 @@ struct Accumulator { #[async_trait] impl Actor for Accumulator { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("Accumulate") { let acc: Accumulate = msg.unpack()?; @@ -72,7 +83,9 @@ impl Actor for Accumulator { if msg_type.ends_with("GetTotal") { return Message::pack(&TotalResponse { total: self.total }); } - Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } @@ -80,7 +93,7 @@ impl Actor for Accumulator { // Single Node Integration Tests // ============================================================================ -mod single_node_tests { +mod single_node { use super::*; #[tokio::test] @@ -333,16 +346,20 @@ mod error_tests { &mut self, msg: Message, _ctx: &mut ActorContext, - ) -> anyhow::Result { + ) -> pulsing_actor::error::Result { if msg.msg_type().ends_with("CrashMessage") { self.crash_count.fetch_add(1, Ordering::SeqCst); - return Err(anyhow::anyhow!("Intentional crash!")); + return Err(PulsingError::from(RuntimeError::Other( + "Intentional crash!".into(), + ))); } if msg.msg_type().ends_with("Ping") { let ping: Ping = msg.unpack()?; return Message::pack(&Pong { result: ping.value }); } - Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } @@ -361,16 +378,18 @@ mod error_tests { .await .unwrap(); - // Send crash message + // receive 返回 Err 时只把错误返回给调用者,actor 不退出 let result: Result = actor_ref.ask(CrashMessage).await; assert!(result.is_err()); assert_eq!(crash_count.load(Ordering::SeqCst), 1); - // With supervision model, errors cause actor to crash (unless supervision is configured) - // So subsequent messages should fail - tokio::time::sleep(std::time::Duration::from_millis(10)).await; + // Actor 仍存活,后续消息应正常处理 let result2: Result = actor_ref.ask(Ping { value: 42 }).await; - assert!(result2.is_err(), "Actor should be dead after error"); + assert!( + result2.is_ok(), + "Actor should still be alive after receive error" + ); + assert_eq!(result2.unwrap().result, 42); system.shutdown().await.unwrap(); } @@ -390,17 +409,14 @@ mod error_tests { .await .unwrap(); - // First crash message crashes the actor - let _: Result = actor_ref.ask(CrashMessage).await; + // receive 返回 Err 时只把错误返回给调用者,actor 不退出 + let r1: Result = actor_ref.ask(CrashMessage).await; + assert!(r1.is_err()); assert_eq!(crash_count.load(Ordering::SeqCst), 1); - // Actor is now dead - 
subsequent messages fail with mailbox closed - tokio::time::sleep(std::time::Duration::from_millis(10)).await; - let result: Result = actor_ref.ask(CrashMessage).await; - assert!(result.is_err(), "Actor should be dead after first error"); - - // Counter doesn't increment because actor is dead - assert_eq!(crash_count.load(Ordering::SeqCst), 1); + let r2: Result = actor_ref.ask(CrashMessage).await; + assert!(r2.is_err()); + assert_eq!(crash_count.load(Ordering::SeqCst), 2); system.shutdown().await.unwrap(); } @@ -443,12 +459,12 @@ mod lifecycle_tests { #[async_trait] impl Actor for LifecycleTracker { - async fn on_start(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { self.events.lock().await.push("started".to_string()); Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { self.events.lock().await.push("stopped".to_string()); Ok(()) } @@ -457,7 +473,7 @@ mod lifecycle_tests { &mut self, msg: Message, _ctx: &mut ActorContext, - ) -> anyhow::Result { + ) -> pulsing_actor::error::Result { self.events .lock() .await diff --git a/crates/pulsing-actor/tests/actor_tests.rs b/crates/pulsing-actor/tests/unit/actor/actor_tests.rs similarity index 91% rename from crates/pulsing-actor/tests/actor_tests.rs rename to crates/pulsing-actor/tests/unit/actor/actor_tests.rs index 254a9f7eb..ac084e4f3 100644 --- a/crates/pulsing-actor/tests/actor_tests.rs +++ b/crates/pulsing-actor/tests/unit/actor/actor_tests.rs @@ -1,5 +1,6 @@ //! Actor core functionality tests +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::Arc; @@ -51,7 +52,11 @@ struct Counter { #[async_trait] impl Actor for Counter { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("Ping") { let ping: Ping = msg.unpack()?; @@ -72,9 +77,14 @@ impl Actor for Counter { return Message::pack(&StateResponse { value: self.count }); } if msg_type.ends_with("ErrorMessage") { - return Err(anyhow::anyhow!("Intentional error for testing")); + return Err(PulsingError::from(RuntimeError::Other( + "Intentional error for testing".into(), + ))); } - Err(anyhow::anyhow!("Unknown message type: {}", msg_type)) + Err(PulsingError::from(RuntimeError::Other(format!( + "Unknown message type: {}", + msg_type + )))) } } @@ -86,21 +96,27 @@ struct LifecycleActor { #[async_trait] impl Actor for LifecycleActor { - async fn on_start(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { self.start_count.fetch_add(1, Ordering::SeqCst); Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { self.stop_count.fetch_add(1, Ordering::SeqCst); Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { if msg.msg_type().ends_with("Ping") { return Message::pack(&Pong { result: 0 }); } - 
Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } @@ -272,12 +288,12 @@ mod error_tests { let result: Result = actor_ref.ask(ErrorMessage).await; assert!(result.is_err()); - // With the supervision model, errors cause the actor to crash - // (unless supervision is configured to restart it) - // So subsequent messages will fail with "mailbox closed" - tokio::time::sleep(std::time::Duration::from_millis(10)).await; + // receive 返回 Err 时只把错误返回给调用者,actor 不退出 let result2: Result = actor_ref.ask(Ping { value: 1 }).await; - assert!(result2.is_err(), "Actor should be dead after error"); + assert!( + result2.is_ok(), + "Actor should still be alive after receive error" + ); let _ = system.shutdown().await; } diff --git a/crates/pulsing-actor/tests/address_tests.rs b/crates/pulsing-actor/tests/unit/actor/address_tests.rs similarity index 95% rename from crates/pulsing-actor/tests/address_tests.rs rename to crates/pulsing-actor/tests/unit/actor/address_tests.rs index 321090975..75c172128 100644 --- a/crates/pulsing-actor/tests/address_tests.rs +++ b/crates/pulsing-actor/tests/unit/actor/address_tests.rs @@ -1,6 +1,7 @@ //! Comprehensive tests for the actor addressing system use pulsing_actor::actor::{ActorId, ActorPath}; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -31,7 +32,11 @@ struct IdentityActor { #[async_trait] impl Actor for IdentityActor { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { if msg.msg_type().ends_with("EchoMsg") { let echo: EchoMsg = msg.unpack()?; self.call_count.fetch_add(1, Ordering::SeqCst); @@ -40,7 +45,9 @@ impl Actor for IdentityActor { from_node: self.node_name.clone(), }); } - Err(anyhow::anyhow!("Unknown message")) + Err(PulsingError::from(RuntimeError::Other( + "Unknown message".into(), + ))) } } diff --git a/crates/pulsing-actor/tests/context_tests.rs b/crates/pulsing-actor/tests/unit/actor/context_tests.rs similarity index 87% rename from crates/pulsing-actor/tests/context_tests.rs rename to crates/pulsing-actor/tests/unit/actor/context_tests.rs index d49178962..6d5ad722e 100644 --- a/crates/pulsing-actor/tests/context_tests.rs +++ b/crates/pulsing-actor/tests/unit/actor/context_tests.rs @@ -15,7 +15,11 @@ struct Target; #[async_trait] impl Actor for Target { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let ping: Ping = msg.unpack()?; Message::pack(&Pong { value: ping.value + 1, @@ -29,7 +33,11 @@ struct Forwarder { #[async_trait] impl Actor for Forwarder { - async fn receive(&mut self, msg: Message, ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let ping: Ping = msg.unpack()?; let target_ref = ctx.actor_ref(&self.target).await?; let pong: Pong = target_ref.ask(Ping { value: ping.value }).await?; diff --git a/crates/pulsing-actor/tests/mailbox_tests.rs b/crates/pulsing-actor/tests/unit/actor/mailbox_tests.rs similarity index 100% rename from crates/pulsing-actor/tests/mailbox_tests.rs rename to crates/pulsing-actor/tests/unit/actor/mailbox_tests.rs diff --git 
a/crates/pulsing-actor/tests/unit/actor/mod.rs b/crates/pulsing-actor/tests/unit/actor/mod.rs new file mode 100644 index 000000000..517f8b198 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/actor/mod.rs @@ -0,0 +1,4 @@ +mod actor_tests; +mod address_tests; +mod context_tests; +mod mailbox_tests; diff --git a/crates/pulsing-actor/tests/unit/behavior/context_tests.rs b/crates/pulsing-actor/tests/unit/behavior/context_tests.rs new file mode 100644 index 000000000..566385076 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/behavior/context_tests.rs @@ -0,0 +1,69 @@ +//! Unit tests for BehaviorContext: name, actor_id, self_ref, is_cancelled. + +use pulsing_actor::behavior::{stateless, BehaviorAction, BehaviorWrapper}; +use pulsing_actor::prelude::*; +use serde::{Deserialize, Serialize}; +use std::sync::Arc; +use tokio::sync::Mutex; + +#[derive(Serialize, Deserialize, Debug)] +enum CtxMsg { + RecordContext, + Ping, +} + +#[derive(Serialize, Deserialize, Debug, Default)] +struct ContextSnapshot { + name: Option, + actor_id_str: Option, + is_cancelled: Option, +} + +#[tokio::test] +async fn test_behavior_context_name_and_actor_id() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let snapshot: Arc> = Arc::new(Mutex::new(ContextSnapshot::default())); + let snap = snapshot.clone(); + let behavior = stateless(move |msg: CtxMsg, ctx| { + let snap = snap.clone(); + Box::pin(async move { + if matches!(msg, CtxMsg::RecordContext) { + let mut s = snap.lock().await; + s.name = Some(ctx.name().to_string()); + s.actor_id_str = Some(ctx.actor_id().to_string()); + s.is_cancelled = Some(ctx.is_cancelled()); + } + BehaviorAction::Same + }) + }); + let wrapper: BehaviorWrapper = behavior.into_actor(); + let actor_ref = system.spawn_named("test/ctx-test", wrapper).await.unwrap(); + actor_ref.tell(CtxMsg::RecordContext).await.unwrap(); + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + let s = snapshot.lock().await; + assert_eq!(s.name.as_deref(), Some("test/ctx-test")); + assert!(!s.actor_id_str.as_deref().unwrap_or("").is_empty()); + assert_eq!(s.is_cancelled, Some(false)); + system.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn test_behavior_context_self_ref_resolves() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let behavior = stateless(|_msg: CtxMsg, ctx| { + Box::pin(async move { + let self_ref = ctx.self_ref(); + assert_eq!(self_ref.name(), "test/self-ref-actor"); + let untyped = self_ref.as_untyped(); + assert!(untyped.is_ok()); + BehaviorAction::Same + }) + }); + let wrapper: BehaviorWrapper = behavior.into_actor(); + let actor_ref = system + .spawn_named("test/self-ref-actor", wrapper) + .await + .unwrap(); + let _: () = actor_ref.ask(CtxMsg::Ping).await.unwrap(); + system.shutdown().await.unwrap(); +} diff --git a/crates/pulsing-actor/tests/unit/behavior/core_tests.rs b/crates/pulsing-actor/tests/unit/behavior/core_tests.rs new file mode 100644 index 000000000..a0ae7bbab --- /dev/null +++ b/crates/pulsing-actor/tests/unit/behavior/core_tests.rs @@ -0,0 +1,78 @@ +//! Unit tests for behavior core: Behavior creation, message handling, BehaviorAction. 
+ +use pulsing_actor::behavior::{stateful, stateless, Behavior, BehaviorAction, BehaviorWrapper}; +use pulsing_actor::prelude::*; +use serde::{Deserialize, Serialize}; + +#[derive(Serialize, Deserialize, Debug)] +enum TestMsg { + Ping, + Add(i32), + Get, +} + +#[tokio::test] +async fn test_stateful_behavior_creation() { + let _behavior: Behavior = stateful(0i32, |count, msg, _ctx| match msg { + TestMsg::Ping => BehaviorAction::Same, + TestMsg::Add(n) => { + *count += n; + BehaviorAction::Same + } + TestMsg::Get => BehaviorAction::Same, + }); +} + +#[tokio::test] +async fn test_stateless_behavior_creation() { + let _behavior: Behavior = stateless(|msg, _ctx| { + Box::pin(async move { + match msg { + TestMsg::Ping => BehaviorAction::Same, + TestMsg::Add(_) => BehaviorAction::stop(), + TestMsg::Get => BehaviorAction::Same, + } + }) + }); +} + +#[tokio::test] +async fn test_behavior_action_stop_helpers() { + let stop_none = BehaviorAction::<()>::stop(); + assert!(stop_none.is_stop()); + let stop_reason = BehaviorAction::<()>::stop_with_reason("done"); + assert!(stop_reason.is_stop()); + assert!(matches!(stop_reason, BehaviorAction::Stop(Some(r)) if r == "done")); +} + +#[tokio::test] +async fn test_behavior_wrapper_receive_same() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let counter = stateful(0i32, |count, msg: TestMsg, _ctx| match msg { + TestMsg::Add(n) => { + *count += n; + BehaviorAction::Same + } + _ => BehaviorAction::Same, + }); + let wrapper: BehaviorWrapper = counter.into_actor(); + let actor_ref = system.spawn(wrapper).await.unwrap(); + actor_ref.tell(TestMsg::Add(10)).await.unwrap(); + actor_ref.tell(TestMsg::Add(5)).await.unwrap(); + let _: () = actor_ref.ask(TestMsg::Get).await.unwrap(); + system.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn test_behavior_wrapper_pack_unpack() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let echo = stateless(|msg: TestMsg, _ctx| { + Box::pin(async move { + let _ = msg; + BehaviorAction::Same + }) + }); + let wrapper: BehaviorWrapper = echo.into_actor(); + let _ref = system.spawn(wrapper).await.unwrap(); + system.shutdown().await.unwrap(); +} diff --git a/crates/pulsing-actor/tests/unit/behavior/mod.rs b/crates/pulsing-actor/tests/unit/behavior/mod.rs new file mode 100644 index 000000000..71793d175 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/behavior/mod.rs @@ -0,0 +1,3 @@ +mod context_tests; +mod core_tests; +mod reference_tests; diff --git a/crates/pulsing-actor/tests/unit/behavior/reference_tests.rs b/crates/pulsing-actor/tests/unit/behavior/reference_tests.rs new file mode 100644 index 000000000..a66e2adb2 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/behavior/reference_tests.rs @@ -0,0 +1,100 @@ +//! Unit tests for TypedRef: tell, ask, ask_timeout, as_untyped, is_alive. 
+ +use pulsing_actor::behavior::{stateless, BehaviorAction, BehaviorWrapper}; +use pulsing_actor::prelude::*; +use serde::{Deserialize, Serialize}; +use std::time::Duration; + +#[derive(Serialize, Deserialize, Debug)] +struct Ping { + value: i32, +} + +#[derive(Serialize, Deserialize, Debug)] +struct Pong { + result: i32, +} + +#[derive(Serialize, Deserialize, Debug)] +enum RefTestMsg { + Ping, +} + +/// Echo actor that doubles Ping.value for Pong.result +struct EchoActor; + +#[async_trait] +impl Actor for EchoActor { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { + if msg.msg_type().ends_with("Ping") { + let ping: Ping = msg.unpack()?; + return Message::pack(&Pong { + result: ping.value * 2, + }); + } + Err(pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other("unknown".to_string()), + )) + } +} + +#[tokio::test] +async fn test_typed_ref_tell_and_ask() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let echo = system + .spawn_named("test/typed-echo", EchoActor) + .await + .unwrap(); + let typed = pulsing_actor::behavior::TypedRef::::new("test/typed-echo", echo.clone()); + assert_eq!(typed.name(), "test/typed-echo"); + assert!(typed.is_alive()); + + typed.tell(Ping { value: 1 }).await.unwrap(); + let pong: Pong = typed.ask(Ping { value: 2 }).await.unwrap(); + assert_eq!(pong.result, 4); + + let untyped = typed.as_untyped().unwrap(); + assert!(untyped.id() == echo.id()); + system.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn test_typed_ref_ask_timeout_ok() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let echo = system + .spawn_named("test/typed-echo-timeout", EchoActor) + .await + .unwrap(); + let typed = pulsing_actor::behavior::TypedRef::::new("test/typed-echo-timeout", echo); + let pong: Pong = typed + .ask_timeout(Ping { value: 10 }, Duration::from_secs(2)) + .await + .unwrap(); + assert_eq!(pong.result, 20); + system.shutdown().await.unwrap(); +} + +#[tokio::test] +async fn test_typed_ref_from_context_typed_ref() { + let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); + let behavior = stateless(|_msg: RefTestMsg, ctx| { + Box::pin(async move { + let other: pulsing_actor::behavior::TypedRef = + ctx.typed_ref("test/typed-echo-from-ctx"); + assert_eq!(other.name(), "test/typed-echo-from-ctx"); + BehaviorAction::Same + }) + }); + let wrapper: BehaviorWrapper = behavior.into_actor(); + let _echo_ref = system + .spawn_named("test/typed-echo-from-ctx", EchoActor) + .await + .unwrap(); + let actor = system.spawn_named("test/caller", wrapper).await.unwrap(); + let _: () = actor.ask(RefTestMsg::Ping).await.unwrap(); + system.shutdown().await.unwrap(); +} diff --git a/crates/pulsing-actor/tests/cluster_tests.rs b/crates/pulsing-actor/tests/unit/cluster/gossip_tests.rs similarity index 98% rename from crates/pulsing-actor/tests/cluster_tests.rs rename to crates/pulsing-actor/tests/unit/cluster/gossip_tests.rs index c64bd6ec0..ea31fa383 100644 --- a/crates/pulsing-actor/tests/cluster_tests.rs +++ b/crates/pulsing-actor/tests/unit/cluster/gossip_tests.rs @@ -1,9 +1,6 @@ //! Cluster and Gossip protocol tests //! //! This file contains integration tests for cluster functionality. -//! 
Unit tests for member types are in tests/cluster/member_tests.rs - -mod cluster; use pulsing_actor::actor::{ActorId, NodeId}; use pulsing_actor::cluster::GossipConfig; diff --git a/crates/pulsing-actor/tests/cluster/member_tests.rs b/crates/pulsing-actor/tests/unit/cluster/member_tests.rs similarity index 100% rename from crates/pulsing-actor/tests/cluster/member_tests.rs rename to crates/pulsing-actor/tests/unit/cluster/member_tests.rs diff --git a/crates/pulsing-actor/tests/unit/cluster/mod.rs b/crates/pulsing-actor/tests/unit/cluster/mod.rs new file mode 100644 index 000000000..15a2a15f4 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/cluster/mod.rs @@ -0,0 +1,3 @@ +mod gossip_tests; +mod member_tests; +mod naming_tests; diff --git a/crates/pulsing-actor/tests/cluster/naming_tests.rs b/crates/pulsing-actor/tests/unit/cluster/naming_tests.rs similarity index 98% rename from crates/pulsing-actor/tests/cluster/naming_tests.rs rename to crates/pulsing-actor/tests/unit/cluster/naming_tests.rs index 0e5a332c6..f49d6ebc6 100644 --- a/crates/pulsing-actor/tests/cluster/naming_tests.rs +++ b/crates/pulsing-actor/tests/unit/cluster/naming_tests.rs @@ -32,7 +32,11 @@ struct TestActor; #[async_trait::async_trait] impl Actor for TestActor { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { // Echo back the message Ok(msg) } @@ -195,7 +199,11 @@ impl Actor for MetadataActor { self.metadata.clone() } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { Ok(msg) } } diff --git a/crates/pulsing-actor/tests/unit/main.rs b/crates/pulsing-actor/tests/unit/main.rs new file mode 100644 index 000000000..69c882134 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/main.rs @@ -0,0 +1,8 @@ +#[path = "../common/mod.rs"] +mod common; + +mod actor; +mod behavior; +mod cluster; +mod system; +mod transport; diff --git a/crates/pulsing-actor/tests/unit/system/mod.rs b/crates/pulsing-actor/tests/unit/system/mod.rs new file mode 100644 index 000000000..47b11184b --- /dev/null +++ b/crates/pulsing-actor/tests/unit/system/mod.rs @@ -0,0 +1,2 @@ +mod supervision_tests; +mod system_actor_tests; diff --git a/crates/pulsing-actor/tests/supervision_tests.rs b/crates/pulsing-actor/tests/unit/system/supervision_tests.rs similarity index 58% rename from crates/pulsing-actor/tests/supervision_tests.rs rename to crates/pulsing-actor/tests/unit/system/supervision_tests.rs index 227d6d484..adbb209a2 100644 --- a/crates/pulsing-actor/tests/supervision_tests.rs +++ b/crates/pulsing-actor/tests/unit/system/supervision_tests.rs @@ -1,3 +1,4 @@ +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use pulsing_actor::supervision::{BackoffStrategy, SupervisionSpec}; use std::sync::atomic::{AtomicU32, Ordering}; @@ -11,11 +12,15 @@ struct FailingActor { #[async_trait] impl Actor for FailingActor { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let count = self.counter.fetch_add(1, Ordering::SeqCst) + 1; if count == self.fail_at { - return Err(anyhow::anyhow!("Boom!")); + return Err(PulsingError::from(RuntimeError::Other("Boom!".into()))); } // Echo @@ 
-55,15 +60,11 @@ async fn test_restart_on_failure() { let resp = actor_ref.send(Message::single("ping", b"1")).await; assert!(resp.is_ok()); - // 2nd message - failure (should crash and restart) + // 2nd message - receive returns Err; the error goes back to the caller, and the actor neither exits nor restarts let resp = actor_ref.send(Message::single("ping", b"2")).await; - assert!(resp.is_err()); // The ask fails because the actor crashed handling it + assert!(resp.is_err()); - // Wait a bit for restart - tokio::time::sleep(Duration::from_millis(50)).await; - - // 3rd message - success (new instance) - // Note: counter is shared, so it will continue from 2 -> 3 + // 3rd message - the same instance is still alive and keeps processing let resp = actor_ref.send(Message::single("ping", b"3")).await; assert!(resp.is_ok()); @@ -79,21 +80,21 @@ #[tokio::test] async fn test_max_restarts_exceeded() { + // receive returning Err does not make the actor exit, so no restart is triggered; the factory is only called once let system = ActorSystem::new(SystemConfig::standalone()).await.unwrap(); let counter = Arc::new(AtomicU32::new(0)); let counter_clone = counter.clone(); - // Fail immediately let factory = move || { counter_clone.fetch_add(1, Ordering::SeqCst); Ok(FailingActor { - counter: Arc::new(AtomicU32::new(0)), // Unused - fail_at: 1, // Fail immediately + counter: Arc::new(AtomicU32::new(0)), + fail_at: 1, // the 1st message returns Err }) }; let spec = SupervisionSpec::on_failure() - .with_max_restarts(2) // Allow 2 restarts + .with_max_restarts(2) .with_backoff(BackoffStrategy { min: Duration::from_millis(1), max: Duration::from_millis(1), @@ -109,32 +110,13 @@ .await .unwrap(); - // 1st crash - let _ = actor_ref.send(Message::single("ping", b"1")).await; - tokio::time::sleep(Duration::from_millis(10)).await; - - // 2nd crash - let _ = actor_ref.send(Message::single("ping", b"2")).await; - tokio::time::sleep(Duration::from_millis(10)).await; - - // 3rd crash - let _ = actor_ref.send(Message::single("ping", b"3")).await; - tokio::time::sleep(Duration::from_millis(10)).await; - - // Should be dead now (Initial start + 2 restarts = 3 failures. Next attempt stops.) - // Wait for supervision loop to exit - tokio::time::sleep(Duration::from_millis(50)).await; - - // Send message to dead actor - let resp = actor_ref.send(Message::single("ping", b"4")).await; - assert!(resp.is_err()); // Mailbox closed - - // Check factory calls: Initial + 2 restarts = 3 calls - // Actually, if it crashes 3 times: - // 1. Start (count=1), Receive -> Crash - // 2. Restart 1 (count=2), Receive -> Crash - // 3. Restart 2 (count=3), Receive -> Crash - // 4. Max restarts exceeded -> Stop - // So factory called 3 times. 
- assert_eq!(counter.load(Ordering::SeqCst), 3); + // 1st message: receive returns Err; only the error is returned to the caller, and the actor does not exit + let r1 = actor_ref.send(Message::single("ping", b"1")).await; + assert!(r1.is_err()); + assert_eq!(counter.load(Ordering::SeqCst), 1); // factory called only once + + // 2nd message: same instance, count=2 != fail_at(1), so it returns Ok + let r2 = actor_ref.send(Message::single("ping", b"2")).await; + assert!(r2.is_ok()); + assert_eq!(counter.load(Ordering::SeqCst), 1); // no restart } diff --git a/crates/pulsing-actor/tests/system_actor_tests.rs b/crates/pulsing-actor/tests/unit/system/system_actor_tests.rs similarity index 100% rename from crates/pulsing-actor/tests/system_actor_tests.rs rename to crates/pulsing-actor/tests/unit/system/system_actor_tests.rs diff --git a/crates/pulsing-actor/tests/unit/transport/benchmark_tests.rs b/crates/pulsing-actor/tests/unit/transport/benchmark_tests.rs new file mode 100644 index 000000000..ab8de55fa --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/benchmark_tests.rs @@ -0,0 +1,169 @@ +//! HTTP/2 performance and benchmark tests + +use crate::common::fixtures::TestHandler; +use pulsing_actor::transport::{Http2Client, Http2Config, Http2Server}; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +#[tokio::test] +async fn test_http2_throughput_benchmark() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::high_throughput(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Arc::new(Http2Client::new(Http2Config::high_throughput())); + + let request_count = 1000; + let start = std::time::Instant::now(); + + for i in 0..request_count { + let _ = client + .ask( + addr, + "/actors/bench", + "Bench", + format!("req-{}", i).into_bytes(), + ) + .await + .unwrap(); + } + + let elapsed = start.elapsed(); + let rps = request_count as f64 / elapsed.as_secs_f64(); + + println!( + "HTTP/2 Sequential Throughput: {} requests in {:?} ({:.0} req/s)", + request_count, elapsed, rps + ); + + assert!(rps > 100.0, "Throughput too low: {} req/s", rps); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_concurrent_throughput_benchmark() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::high_throughput(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Arc::new(Http2Client::new(Http2Config::high_throughput())); + + let request_count = 1000; + let concurrency = 50; + let start = std::time::Instant::now(); + + let mut handles = Vec::new(); + for i in 0..request_count { + let client = client.clone(); + let handle = tokio::spawn(async move { + client + .ask( + addr, + "/actors/bench", + "Bench", + format!("req-{}", i).into_bytes(), + ) + .await + }); + handles.push(handle); + + if handles.len() >= concurrency { + let results: Vec<_> = futures::future::join_all(handles.drain(..)).await; + for r in results { + r.unwrap().unwrap(); + } + } + } + + let results: Vec<_> = futures::future::join_all(handles).await; + for r in results { + r.unwrap().unwrap(); + } + + let elapsed = start.elapsed(); + let rps = request_count as f64 / elapsed.as_secs_f64(); + + println!( + "HTTP/2 Concurrent Throughput ({} concurrency): {} requests in {:?} ({:.0} req/s)", + concurrency, request_count, elapsed, rps + ); + + assert!(rps > 100.0, "Concurrent throughput 
too low: {} req/s", rps); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_latency_benchmark() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::low_latency(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::low_latency()); + + let request_count = 100; + let mut latencies = Vec::with_capacity(request_count); + + for i in 0..request_count { + let start = std::time::Instant::now(); + let _ = client + .ask( + addr, + "/actors/latency", + "Ping", + format!("req-{}", i).into_bytes(), + ) + .await + .unwrap(); + latencies.push(start.elapsed()); + } + + latencies.sort(); + let min = latencies.first().unwrap(); + let max = latencies.last().unwrap(); + let median = latencies[request_count / 2]; + let p99 = latencies[(request_count * 99) / 100]; + let avg: std::time::Duration = + latencies.iter().sum::() / request_count as u32; + + println!( + "HTTP/2 Latency: min={:?}, avg={:?}, median={:?}, p99={:?}, max={:?}", + min, avg, median, p99, max + ); + + assert!( + p99 < std::time::Duration::from_millis(500), + "P99 latency too high: {:?}", + p99 + ); + + cancel.cancel(); +} diff --git a/crates/pulsing-actor/tests/unit/transport/client_tests.rs b/crates/pulsing-actor/tests/unit/transport/client_tests.rs new file mode 100644 index 000000000..0db6a10ee --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/client_tests.rs @@ -0,0 +1,521 @@ +//! HTTP/2 Client tests + +use crate::common::fixtures::{StreamingHandler, TestCounters, TestHandler}; +use pulsing_actor::actor::{ActorId, Message}; +use pulsing_actor::transport::{Http2Client, Http2Config, Http2RemoteTransport, Http2Server}; +use std::sync::atomic::Ordering; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +#[tokio::test] +async fn test_http2_client_creation() { + let client = Http2Client::new(Http2Config::default()); + // Client should be clonable + let _cloned = client.clone(); +} + +#[tokio::test] +async fn test_http2_ask_request() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Http2Client::new(Http2Config::default()); + let response = client + .ask(addr, "/actors/test", "TestMsg", b"hello".to_vec()) + .await + .unwrap(); + + let response_str = String::from_utf8_lossy(&response); + assert!(response_str.contains("/actors/test")); + assert!(response_str.contains("TestMsg")); + assert!(response_str.contains("hello")); + + assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_tell_request() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Http2Client::new(Http2Config::default()); + client + .tell(addr, "/actors/test", "FireAndForget", b"data".to_vec()) + .await + .unwrap(); + + 
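+    // tell() is fire-and-forget, so give the server handler a moment to run before asserting the counter.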
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + assert_eq!(counters.tell_count.load(Ordering::SeqCst), 1); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_multiple_requests() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::default()); + + for i in 0..10 { + let response = client + .ask( + addr, + "/actors/test", + "Msg", + format!("request-{}", i).into_bytes(), + ) + .await + .unwrap(); + + let response_str = String::from_utf8_lossy(&response); + assert!(response_str.contains(&format!("request-{}", i))); + } + + assert_eq!(counters.ask_count.load(Ordering::SeqCst), 10); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_concurrent_requests() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Arc::new(Http2Client::new(Http2Config::default())); + + let mut handles = Vec::new(); + for i in 0..20 { + let client = client.clone(); + let handle = tokio::spawn(async move { + client + .ask( + addr, + "/actors/test", + "Concurrent", + format!("req-{}", i).into_bytes(), + ) + .await + }); + handles.push(handle); + } + + let results: Vec<_> = futures::future::join_all(handles).await; + + for result in results { + assert!(result.unwrap().is_ok()); + } + + assert_eq!(counters.ask_count.load(Ordering::SeqCst), 20); + + cancel.cancel(); +} + +// ============================================================================ +// Http2RemoteTransport Tests +// ============================================================================ + +#[tokio::test] +async fn test_http2_remote_transport_ask() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Arc::new(Http2Client::new(Http2Config::default())); + let transport = Http2RemoteTransport::new(client, addr, "test-actor".to_string()); + + use pulsing_actor::actor::RemoteTransport; + + let actor_id = ActorId::generate(); + let response = transport + .request(&actor_id, "TestType", b"payload".to_vec()) + .await + .unwrap(); + + let response_str = String::from_utf8_lossy(&response); + assert!(response_str.contains("test-actor")); + assert!(response_str.contains("TestType")); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_remote_transport_tell() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Arc::new(Http2Client::new(Http2Config::default())); + let transport = Http2RemoteTransport::new(client, 
addr, "fire-actor".to_string()); + + use pulsing_actor::actor::RemoteTransport; + + let actor_id = ActorId::generate(); + transport + .send(&actor_id, "FireMsg", b"data".to_vec()) + .await + .unwrap(); + + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + assert_eq!(counters.tell_count.load(Ordering::SeqCst), 1); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_remote_transport_named_path() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Arc::new(Http2Client::new(Http2Config::default())); + use pulsing_actor::actor::ActorPath; + let path = ActorPath::new("services/llm/worker").unwrap(); + let transport = Http2RemoteTransport::new_named(client, addr, path); + + use pulsing_actor::actor::RemoteTransport; + + let actor_id = ActorId::generate(); + let response = transport + .request(&actor_id, "Inference", b"prompt".to_vec()) + .await + .unwrap(); + + let response_str = String::from_utf8_lossy(&response); + assert!(response_str.contains("services/llm/worker")); + + cancel.cancel(); +} + +// ============================================================================ +// Error Handling Tests +// ============================================================================ + +#[tokio::test] +async fn test_http2_connection_refused() { + let client = Http2Client::new(Http2Config::default()); + + let result = client + .ask( + "127.0.0.1:1".parse().unwrap(), + "/actors/test", + "Test", + vec![], + ) + .await; + + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_http2_server_shutdown() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Http2Client::new(Http2Config::default()); + let result = client.ask(addr, "/actors/test", "Test", vec![]).await; + assert!(result.is_ok()); + + server.shutdown(); + + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; +} + +// ============================================================================ +// Unified Send Message Tests +// ============================================================================ + +#[tokio::test] +async fn test_http2_unified_send_message() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(TestHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Http2Client::new(Http2Config::default()); + let response = client + .send_message(addr, "/actors/test", "TestMsg", b"test-payload".to_vec()) + .await + .unwrap(); + + assert!(response.is_single()); + + let Message::Single { data, .. 
} = response else { + panic!("Expected single message"); + }; + let data_str = String::from_utf8_lossy(&data); + assert!(data_str.contains("/actors/test")); + assert!(data_str.contains("TestMsg")); + + assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); + + cancel.cancel(); +} + +// ============================================================================ +// Connection Pool Tests +// ============================================================================ + +#[tokio::test] +async fn test_http2_connection_pool_reuse() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::default()); + + for i in 0..10 { + let _ = client + .ask( + addr, + "/actors/pool-test", + "Msg", + format!("req-{}", i).into_bytes(), + ) + .await + .unwrap(); + } + + let stats = client.stats(); + let created = stats.connections_created.load(Ordering::Relaxed); + let reused = stats.connections_reused.load(Ordering::Relaxed); + + println!("Connection Pool: created={}, reused={}", created, reused); + + assert!(created <= 2, "Too many connections created: {}", created); + assert!(reused >= 8, "Not enough connection reuse: {}", reused); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_retry_on_connection_error() { + use pulsing_actor::transport::{Http2ClientBuilder, RetryConfig}; + use std::time::Duration; + + let client = Http2ClientBuilder::new() + .retry_config(RetryConfig::with_max_retries(2).initial_delay(Duration::from_millis(10))) + .connect_timeout(Duration::from_millis(100)) + .build(); + + let result: Result, _> = client + .ask( + "127.0.0.1:1".parse().unwrap(), + "/actors/test", + "Msg", + vec![], + ) + .await; + + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("Connection") + || err.contains("connect") + || err.contains("timeout") + || err.contains("backing off"), + "Unexpected error: {}", + err + ); +} + +// ============================================================================ +// Streaming Request Tests +// ============================================================================ + +#[tokio::test] +async fn test_http2_streaming_request() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(StreamingHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::default()); + + let (tx, rx) = tokio::sync::mpsc::channel::>(10); + + tokio::spawn(async move { + for i in 0..5 { + let msg = Message::single("chunk", format!("data-{}", i).into_bytes()); + if tx.send(Ok(msg)).await.is_err() { + break; + } + } + }); + + let stream_msg = Message::from_channel("StreamRequest", rx); + + let response = client + .send_message_full(addr, "/actors/stream_test", stream_msg) + .await + .unwrap(); + + assert!(response.is_single()); + let Message::Single { data, .. 
} = response else { + panic!("Expected single message"); + }; + let response_str = String::from_utf8_lossy(&data); + assert!(response_str.contains("/actors/stream_test")); + assert!(response_str.contains("collected:")); + + assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_single_request_with_full_api() { + let counters = Arc::new(TestCounters::default()); + let handler = Arc::new(StreamingHandler::with_counters(counters.clone())); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::default()); + + let msg = Message::single("TestType", b"test-payload".to_vec()); + let response = client + .send_message_full(addr, "/actors/single_test", msg) + .await + .unwrap(); + + assert!(response.is_single()); + let Message::Single { data, .. } = response else { + panic!("Expected single message"); + }; + let response_str = String::from_utf8_lossy(&data); + assert!(response_str.contains("/actors/single_test")); + assert!(response_str.contains("TestType")); + + assert_eq!(counters.ask_count.load(Ordering::SeqCst), 1); + + cancel.cancel(); +} diff --git a/crates/pulsing-actor/tests/unit/transport/mod.rs b/crates/pulsing-actor/tests/unit/transport/mod.rs new file mode 100644 index 000000000..09aa0cde6 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/mod.rs @@ -0,0 +1,8 @@ +mod benchmark_tests; +mod client_tests; +mod protocol_tests; +mod rest_api_tests; +mod server_tests; +#[cfg(feature = "tls")] +mod tls_tests; +mod tracing_tests; diff --git a/crates/pulsing-actor/tests/unit/transport/protocol_tests.rs b/crates/pulsing-actor/tests/unit/transport/protocol_tests.rs new file mode 100644 index 000000000..95f9e2f92 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/protocol_tests.rs @@ -0,0 +1,77 @@ +//! 
Stream frame, message mode, and request type tests + +// ============================================================================ +// Stream Frame Tests +// ============================================================================ + +#[test] +fn test_stream_frame_data() { + use pulsing_actor::transport::StreamFrame; + + let frame = StreamFrame::data("token", b"hello"); + assert_eq!(frame.msg_type, "token"); + assert!(!frame.end); + assert!(frame.error.is_none()); + assert_eq!(frame.get_data(), b"hello"); +} + +#[test] +fn test_stream_frame_end() { + use pulsing_actor::transport::StreamFrame; + + let frame = StreamFrame::end(); + assert!(frame.end); + assert!(frame.error.is_none()); +} + +#[test] +fn test_stream_frame_error() { + use pulsing_actor::transport::StreamFrame; + + let frame = StreamFrame::error("something went wrong"); + assert!(frame.end); + assert!(frame.is_error()); + assert_eq!(frame.error.as_ref().unwrap(), "something went wrong"); +} + +#[test] +fn test_stream_frame_binary_roundtrip() { + use pulsing_actor::transport::StreamFrame; + + let original = StreamFrame::data("response", b"world"); + let bytes = original.to_binary(); + let parsed = StreamFrame::from_binary(&bytes).unwrap(); + + assert_eq!(parsed.msg_type, "response"); + assert_eq!(parsed.get_data(), b"world"); +} + +// ============================================================================ +// Message Mode Tests +// ============================================================================ + +#[test] +fn test_message_mode_conversion() { + use pulsing_actor::transport::MessageMode; + + assert_eq!(MessageMode::Ask.as_str(), "ask"); + assert_eq!(MessageMode::Tell.as_str(), "tell"); + assert_eq!(MessageMode::Stream.as_str(), "stream"); + + assert_eq!(MessageMode::parse("ask"), Some(MessageMode::Ask)); + assert_eq!(MessageMode::parse("TELL"), Some(MessageMode::Tell)); + assert_eq!(MessageMode::parse("Stream"), Some(MessageMode::Stream)); + assert_eq!(MessageMode::parse("invalid"), None); +} + +#[test] +fn test_request_type_conversion() { + use pulsing_actor::transport::RequestType; + + assert_eq!(RequestType::Single.as_str(), "single"); + assert_eq!(RequestType::Stream.as_str(), "stream"); + + assert_eq!(RequestType::parse("single"), Some(RequestType::Single)); + assert_eq!(RequestType::parse("STREAM"), Some(RequestType::Stream)); + assert_eq!(RequestType::parse("invalid"), None); +} diff --git a/crates/pulsing-actor/tests/unit/transport/rest_api_tests.rs b/crates/pulsing-actor/tests/unit/transport/rest_api_tests.rs new file mode 100644 index 000000000..a8e728870 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/rest_api_tests.rs @@ -0,0 +1,96 @@ +//! 
REST API endpoint tests + +use crate::common::fixtures::TestHandler; +use pulsing_actor::transport::{Http2Config, Http2Server, Http2ServerHandler}; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +/// Test /cluster/members endpoint +#[tokio::test] +async fn test_cluster_members_endpoint() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let _server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler.clone(), + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let members = handler.cluster_members().await; + assert!(members.is_array()); + let members_array = members.as_array().unwrap(); + assert_eq!(members_array.len(), 2); + + let member = &members_array[0]; + assert!(member.get("node_id").is_some()); + assert!(member.get("addr").is_some()); + assert!(member.get("status").is_some()); + + cancel.cancel(); +} + +/// Test /actors endpoint +#[tokio::test] +async fn test_actors_list_endpoint() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let _server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler.clone(), + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let actors = handler.actors_list(false).await; + assert!(actors.is_array()); + let actors_array = actors.as_array().unwrap(); + assert_eq!(actors_array.len(), 2); + + let actor = &actors_array[0]; + assert!(actor.get("name").is_some()); + assert!(actor.get("type").is_some()); + assert_eq!(actor.get("type").unwrap(), "user"); + + let all_actors = handler.actors_list(true).await; + let all_actors_array = all_actors.as_array().unwrap(); + assert_eq!(all_actors_array.len(), 3); + + cancel.cancel(); +} + +/// Test actor metadata in actors list +#[tokio::test] +async fn test_actors_list_metadata() { + let handler = Arc::new(TestHandler::new()); + + let actors = handler.actors_list(false).await; + let actors_array = actors.as_array().unwrap(); + let actor = &actors_array[0]; + + assert!(actor.get("actor_id").is_some()); + assert!(actor.get("class").is_some()); + assert!(actor.get("module").is_some()); + + assert_eq!(actor.get("class").unwrap(), "Counter"); + assert_eq!(actor.get("module").unwrap(), "__main__"); +} + +/// Test health check endpoint returns expected structure +#[tokio::test] +async fn test_health_check_endpoint() { + let handler = Arc::new(TestHandler::new()); + + let health = handler.health_check().await; + + assert!(health.get("status").is_some()); + assert_eq!(health.get("status").unwrap(), "healthy"); + assert!(health.get("ask_count").is_some()); + assert!(health.get("tell_count").is_some()); +} diff --git a/crates/pulsing-actor/tests/unit/transport/server_tests.rs b/crates/pulsing-actor/tests/unit/transport/server_tests.rs new file mode 100644 index 000000000..d05859c8f --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/server_tests.rs @@ -0,0 +1,92 @@ +//! 
HTTP/2 Server tests + +use crate::common::fixtures::TestHandler; +use pulsing_actor::transport::{Http2Config, Http2Server}; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +#[tokio::test] +async fn test_http2_server_creation() { + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + // Server should be bound to a valid port + assert_ne!(server.local_addr().port(), 0); + + // Cleanup + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_server_multiple_instances() { + let cancel = CancellationToken::new(); + + // Create multiple servers + let mut servers = Vec::new(); + for _ in 0..3 { + let handler = Arc::new(TestHandler::new()); + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + servers.push(server); + } + + // All servers should have different ports + let ports: Vec = servers.iter().map(|s| s.local_addr().port()).collect(); + let unique_ports: std::collections::HashSet<_> = ports.iter().collect(); + assert_eq!(ports.len(), unique_ports.len()); + + // Cleanup + cancel.cancel(); +} + +#[tokio::test] +async fn test_http2_custom_config() { + let config = Http2Config::new() + .max_concurrent_streams(50) + .initial_window_size(1024 * 1024) + .connect_timeout(std::time::Duration::from_secs(10)) + .request_timeout(std::time::Duration::from_secs(60)); + + assert_eq!(config.max_concurrent_streams, 50); + assert_eq!(config.initial_window_size, 1024 * 1024); + assert_eq!(config.connect_timeout, std::time::Duration::from_secs(10)); + assert_eq!(config.request_timeout, std::time::Duration::from_secs(60)); +} + +#[tokio::test] +async fn test_http2_server_with_custom_config() { + let mut config = Http2Config::new().max_concurrent_streams(100); + config.max_frame_size = 32 * 1024; + + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + config, + cancel.clone(), + ) + .await + .unwrap(); + + assert_ne!(server.local_addr().port(), 0); + + // Cleanup + cancel.cancel(); +} diff --git a/crates/pulsing-actor/tests/unit/transport/tls_tests.rs b/crates/pulsing-actor/tests/unit/transport/tls_tests.rs new file mode 100644 index 000000000..2cae66d61 --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/tls_tests.rs @@ -0,0 +1,107 @@ +//! 
TLS tests (requires `tls` feature) + +use crate::common::fixtures::TestHandler; +use pulsing_actor::transport::http2::TlsConfig; +use pulsing_actor::transport::{Http2Client, Http2Config, Http2Server}; +use std::sync::Arc; +use tokio_util::sync::CancellationToken; + +/// Test TLS configuration creation from passphrase +#[test] +fn test_tls_config_from_passphrase() { + let config = TlsConfig::from_passphrase("test-cluster-password"); + assert!(config.is_ok(), "TLS config creation failed: {:?}", config); +} + +/// Test that same passphrase produces deterministic CA +#[test] +fn test_tls_deterministic_ca() { + let config1 = TlsConfig::from_passphrase("deterministic-test-password").unwrap(); + let config2 = TlsConfig::from_passphrase("deterministic-test-password").unwrap(); + + assert_eq!(config1.passphrase_hash(), config2.passphrase_hash()); +} + +/// Test that different passphrase produces different CA +#[test] +fn test_tls_different_passphrase() { + let config1 = TlsConfig::from_passphrase("password-one").unwrap(); + let config2 = TlsConfig::from_passphrase("password-two").unwrap(); + + assert_ne!(config1.passphrase_hash(), config2.passphrase_hash()); +} + +/// Test TLS-enabled HTTP/2 server and client communication +#[tokio::test] +async fn test_tls_server_client_communication() { + let tls_config = TlsConfig::from_passphrase("test-cluster-tls").unwrap(); + + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let http2_config = Http2Config::default().tls_config(tls_config.clone()); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler.clone(), + http2_config.clone(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client = Http2Client::new(http2_config); + + let response = client + .ask(addr, "/actors/test", "test-msg", b"hello tls".to_vec()) + .await; + + assert!(response.is_ok(), "TLS request failed: {:?}", response); + + let body = response.unwrap(); + let response_str = String::from_utf8_lossy(&body); + assert!( + response_str.contains("hello tls"), + "Response should contain original message" + ); + + cancel.cancel(); +} + +/// Test that different passphrase fails TLS handshake +#[tokio::test] +async fn test_tls_different_passphrase_fails() { + let server_tls = TlsConfig::from_passphrase("server-password").unwrap(); + let client_tls = TlsConfig::from_passphrase("wrong-password").unwrap(); + + let handler = Arc::new(TestHandler::new()); + let cancel = CancellationToken::new(); + + let server_config = Http2Config::default().tls_config(server_tls); + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler, + server_config, + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + + let client_config = Http2Config::default().tls_config(client_tls); + let client = Http2Client::new(client_config); + + let response = client + .ask(addr, "/actors/test", "test", b"test".to_vec()) + .await; + + assert!( + response.is_err(), + "Request with different passphrase should fail" + ); + + cancel.cancel(); +} diff --git a/crates/pulsing-actor/tests/unit/transport/tracing_tests.rs b/crates/pulsing-actor/tests/unit/transport/tracing_tests.rs new file mode 100644 index 000000000..20eb5f5ed --- /dev/null +++ b/crates/pulsing-actor/tests/unit/transport/tracing_tests.rs @@ -0,0 +1,270 @@ +//! 
Tracing and trace context tests + +use crate::common::fixtures::TestCounters; +use pulsing_actor::actor::Message; +use pulsing_actor::tracing::opentelemetry::trace::TraceContextExt; +use pulsing_actor::tracing::{TraceContext, TRACEPARENT_HEADER}; +use pulsing_actor::transport::{Http2Client, Http2Config, Http2Server, Http2ServerHandler}; +use std::sync::atomic::Ordering; +use std::sync::{Arc, Mutex}; +use tokio_util::sync::CancellationToken; + +/// Handler that captures trace context from incoming requests +struct TracingTestHandler { + counters: Arc, + captured_traces: Arc>>>, +} + +impl TracingTestHandler { + fn new() -> Self { + Self { + counters: Arc::new(TestCounters::default()), + captured_traces: Arc::new(Mutex::new(Vec::new())), + } + } + + #[allow(dead_code)] + fn captured_traces(&self) -> Vec> { + self.captured_traces.lock().unwrap().clone() + } +} + +#[async_trait::async_trait] +impl Http2ServerHandler for TracingTestHandler { + async fn handle_message_simple( + &self, + path: &str, + msg_type: &str, + payload: Vec, + ) -> pulsing_actor::error::Result { + self.counters.ask_count.fetch_add(1, Ordering::SeqCst); + + let trace_ctx = TraceContext::from_current(); + self.captured_traces + .lock() + .unwrap() + .push(trace_ctx.map(|t| t.to_traceparent())); + + let response = format!("{}:{}:{}", path, msg_type, payload.len()); + Ok(Message::single("traced", response.into_bytes())) + } + + async fn handle_tell( + &self, + _path: &str, + _msg_type: &str, + _payload: Vec, + ) -> pulsing_actor::error::Result<()> { + self.counters.tell_count.fetch_add(1, Ordering::SeqCst); + Ok(()) + } + + async fn handle_gossip( + &self, + _payload: Vec, + _peer_addr: std::net::SocketAddr, + ) -> pulsing_actor::error::Result>> { + Ok(None) + } + + fn as_any(&self) -> &dyn std::any::Any { + self + } +} + +#[test] +fn test_trace_context_creation() { + let ctx = TraceContext::default(); + assert_eq!(ctx.trace_id.len(), 32); + assert_eq!(ctx.span_id.len(), 16); + assert_eq!(ctx.trace_flags, 0x01); // Sampled +} + +#[test] +fn test_trace_context_to_traceparent() { + let ctx = TraceContext { + trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(), + span_id: "b7ad6b7169203331".to_string(), + trace_flags: 0x01, + trace_state: None, + }; + + let header = ctx.to_traceparent(); + assert_eq!( + header, + "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01" + ); +} + +#[test] +fn test_trace_context_from_traceparent() { + let header = "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-01"; + let ctx = TraceContext::from_traceparent(header).unwrap(); + + assert_eq!(ctx.trace_id, "0af7651916cd43dd8448eb211c80319c"); + assert_eq!(ctx.span_id, "b7ad6b7169203331"); + assert_eq!(ctx.trace_flags, 0x01); +} + +#[test] +fn test_trace_context_roundtrip() { + let original = TraceContext::default(); + let header = original.to_traceparent(); + let parsed = TraceContext::from_traceparent(&header).unwrap(); + + assert_eq!(original.trace_id, parsed.trace_id); + assert_eq!(original.span_id, parsed.span_id); + assert_eq!(original.trace_flags, parsed.trace_flags); +} + +#[test] +fn test_trace_context_child() { + let parent = TraceContext::default(); + let child = parent.child(); + + // Same trace ID + assert_eq!(parent.trace_id, child.trace_id); + // Different span ID + assert_ne!(parent.span_id, child.span_id); + // Same flags + assert_eq!(parent.trace_flags, child.trace_flags); +} + +#[test] +fn test_invalid_traceparent_formats() { + // Too few parts + assert!(TraceContext::from_traceparent("invalid").is_none()); + 
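+    // A well-formed traceparent is `version-traceid-spanid-flags` (2, 32, 16 and 2 hex chars).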
assert!(TraceContext::from_traceparent("00-abc-def").is_none()); + + // Wrong lengths + assert!(TraceContext::from_traceparent("00-short-id-01").is_none()); + assert!( + TraceContext::from_traceparent("00-0af7651916cd43dd8448eb211c80319c-short-01").is_none() + ); + + // Invalid hex in flags + assert!(TraceContext::from_traceparent( + "00-0af7651916cd43dd8448eb211c80319c-b7ad6b7169203331-zz" + ) + .is_none()); +} + +#[test] +fn test_traceparent_header_constant() { + assert_eq!(TRACEPARENT_HEADER, "traceparent"); +} + +#[test] +fn test_trace_context_not_sampled() { + let ctx = TraceContext { + trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(), + span_id: "b7ad6b7169203331".to_string(), + trace_flags: 0x00, // Not sampled + trace_state: None, + }; + + let header = ctx.to_traceparent(); + assert!(header.ends_with("-00")); + + let parsed = TraceContext::from_traceparent(&header).unwrap(); + assert_eq!(parsed.trace_flags, 0x00); +} + +#[test] +fn test_new_child_span_id_uniqueness() { + let ids: Vec = (0..100) + .map(|_| TraceContext::new_child_span_id()) + .collect(); + + // All IDs should be unique + let mut unique_ids = ids.clone(); + unique_ids.sort(); + unique_ids.dedup(); + assert_eq!(ids.len(), unique_ids.len()); + + // All IDs should be 16 hex chars + for id in &ids { + assert_eq!(id.len(), 16); + assert!(id.chars().all(|c| c.is_ascii_hexdigit())); + } +} + +#[tokio::test] +async fn test_http2_request_with_tracing() { + let handler = Arc::new(TracingTestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler.clone(), + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::default()); + + let response = client + .ask(addr, "/actors/traced", "test", b"hello".to_vec()) + .await + .unwrap(); + + assert!(!response.is_empty()); + + cancel.cancel(); +} + +#[tokio::test] +async fn test_multiple_requests_different_traces() { + let handler = Arc::new(TracingTestHandler::new()); + let cancel = CancellationToken::new(); + + let server = Http2Server::new( + "127.0.0.1:0".parse().unwrap(), + handler.clone(), + Http2Config::default(), + cancel.clone(), + ) + .await + .unwrap(); + + let addr = server.local_addr(); + let client = Http2Client::new(Http2Config::default()); + + for i in 0..3 { + let _ = client + .ask( + addr, + "/actors/test", + "type", + format!("msg-{}", i).into_bytes(), + ) + .await + .unwrap(); + } + + assert_eq!(handler.counters.ask_count.load(Ordering::SeqCst), 3); + + cancel.cancel(); +} + +#[test] +fn test_trace_context_otel_conversion() { + let ctx = TraceContext { + trace_id: "0af7651916cd43dd8448eb211c80319c".to_string(), + span_id: "b7ad6b7169203331".to_string(), + trace_flags: 0x01, + trace_state: None, + }; + + let otel_ctx = ctx.to_otel_context(); + + assert!(!otel_ctx + .span() + .span_context() + .trace_id() + .to_string() + .is_empty()); +} diff --git a/crates/pulsing-bench/src/actors/console_renderer.rs b/crates/pulsing-bench/src/actors/console_renderer.rs index 262c06098..81efa9f08 100644 --- a/crates/pulsing-bench/src/actors/console_renderer.rs +++ b/crates/pulsing-bench/src/actors/console_renderer.rs @@ -343,17 +343,21 @@ fn create_progress_bar(progress: f64) -> String { #[async_trait] impl Actor for ConsoleRendererActor { - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> 
{ info!("ConsoleRenderer started with actor_id {:?}", ctx.id()); Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!("ConsoleRenderer stopped"); Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("DisplayUpdate") { diff --git a/crates/pulsing-bench/src/actors/coordinator.rs b/crates/pulsing-bench/src/actors/coordinator.rs index 573d895b4..6e4a99a78 100644 --- a/crates/pulsing-bench/src/actors/coordinator.rs +++ b/crates/pulsing-bench/src/actors/coordinator.rs @@ -12,6 +12,7 @@ use super::scheduler::{RequestGenerator, SimpleRequestGenerator, TokenizedReques use super::{ConsoleRendererActor, MetricsAggregatorActor, SchedulerActor, WorkerActor}; use crate::tokenizer::TokenCounter; use async_trait::async_trait; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use std::sync::Arc; use std::time::Duration; @@ -486,22 +487,29 @@ impl Default for CoordinatorActor { #[async_trait] impl Actor for CoordinatorActor { - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!("Coordinator started with actor_id {:?}", ctx.id()); Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!("Coordinator stopped"); Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("StartBenchmark") { let start: StartBenchmark = msg.unpack()?; - let result = self.start_benchmark(start).await?; + let result = self + .start_benchmark(start) + .await + .map_err(|e| PulsingError::from(RuntimeError::Other(e.to_string())))?; return Message::pack(&result); } @@ -525,9 +533,14 @@ impl Actor for CoordinatorActor { if let Some(ref report) = self.final_report { return Message::pack(report); } - return Err(anyhow::anyhow!("No report available")); + return Err(PulsingError::from(RuntimeError::Other( + "No report available".into(), + ))); } - Err(anyhow::anyhow!("Unknown message type: {}", msg_type)) + Err(PulsingError::from(RuntimeError::Other(format!( + "Unknown message type: {}", + msg_type + )))) } } diff --git a/crates/pulsing-bench/src/actors/metrics_aggregator.rs b/crates/pulsing-bench/src/actors/metrics_aggregator.rs index 24cbbab4d..5616531e2 100644 --- a/crates/pulsing-bench/src/actors/metrics_aggregator.rs +++ b/crates/pulsing-bench/src/actors/metrics_aggregator.rs @@ -8,6 +8,7 @@ use super::messages::*; use async_trait::async_trait; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use std::time::{Duration, Instant}; use tracing::info; @@ -501,18 +502,22 @@ pub struct FinalReport { #[async_trait] impl Actor for MetricsAggregatorActor { - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!("MetricsAggregator started with actor_id {:?}", 
ctx.id()); Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { self.finalize(); info!("MetricsAggregator stopped"); Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("RegisterRun") { @@ -547,7 +552,10 @@ impl Actor for MetricsAggregatorActor { return Message::pack(&self.get_report()); } - Err(anyhow::anyhow!("Unknown message type: {}", msg_type)) + Err(PulsingError::from(RuntimeError::Other(format!( + "Unknown message type: {}", + msg_type + )))) } } diff --git a/crates/pulsing-bench/src/actors/scheduler.rs b/crates/pulsing-bench/src/actors/scheduler.rs index dcdafc622..2bf42e6ad 100644 --- a/crates/pulsing-bench/src/actors/scheduler.rs +++ b/crates/pulsing-bench/src/actors/scheduler.rs @@ -3,6 +3,7 @@ use super::messages::*; use crate::tokenizer::TokenCounter; use async_trait::async_trait; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::Arc; @@ -449,12 +450,12 @@ impl Default for SchedulerActor { #[async_trait] impl Actor for SchedulerActor { - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!("Scheduler started with actor_id {:?}", ctx.id()); Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { self.is_active.store(false, Ordering::SeqCst); info!( "Scheduler stopped. 
Total sent: {}", @@ -463,7 +464,11 @@ impl Actor for SchedulerActor { Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("ConfigureScheduler") { @@ -495,7 +500,10 @@ impl Actor for SchedulerActor { return Message::pack(&self.get_progress()); } - Err(anyhow::anyhow!("Unknown message type: {}", msg_type)) + Err(PulsingError::from(RuntimeError::Other(format!( + "Unknown message type: {}", + msg_type + )))) } } diff --git a/crates/pulsing-bench/src/actors/worker.rs b/crates/pulsing-bench/src/actors/worker.rs index c8e5451ae..3f30a97fc 100644 --- a/crates/pulsing-bench/src/actors/worker.rs +++ b/crates/pulsing-bench/src/actors/worker.rs @@ -3,6 +3,7 @@ use super::messages::*; use async_trait::async_trait; use futures_util::StreamExt; +use pulsing_actor::error::{PulsingError, RuntimeError}; use pulsing_actor::prelude::*; use reqwest::Client; use reqwest_eventsource::{Event as SseEvent, EventSource}; @@ -204,7 +205,7 @@ impl WorkerActor { #[async_trait] impl Actor for WorkerActor { - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!( "Worker {} started with actor_id {:?}", self.worker_id, @@ -213,7 +214,7 @@ impl Actor for WorkerActor { Ok(()) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { info!( "Worker {} stopped. Completed: {}, Failed: {}", self.worker_id, @@ -223,7 +224,11 @@ impl Actor for WorkerActor { Ok(()) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let msg_type = msg.msg_type(); if msg_type.ends_with("SendRequest") { @@ -253,7 +258,10 @@ impl Actor for WorkerActor { return Message::pack(&status); } - Err(anyhow::anyhow!("Unknown message type: {}", msg_type)) + Err(PulsingError::from(RuntimeError::Other(format!( + "Unknown message type: {}", + msg_type + )))) } } diff --git a/crates/pulsing-py/src/actor.rs b/crates/pulsing-py/src/actor.rs index 40349461e..7a7fc638b 100644 --- a/crates/pulsing-py/src/actor.rs +++ b/crates/pulsing-py/src/actor.rs @@ -2,7 +2,6 @@ use futures::StreamExt; use pulsing_actor::actor::{ActorId, ActorPath, NodeId}; -use pulsing_actor::error::PulsingError; use pulsing_actor::prelude::*; use pulsing_actor::supervision::{BackoffStrategy, RestartPolicy, SupervisionSpec}; use pyo3::exceptions::{PyRuntimeError, PyStopAsyncIteration, PyValueError}; @@ -14,27 +13,21 @@ use std::sync::Mutex as StdMutex; use tokio::sync::mpsc; use tokio::sync::Mutex as TokioMutex; -use crate::errors::pulsing_error_to_py_err_direct; +use crate::errors::pulsing_error_to_py_err; use crate::python_error_converter::convert_python_exception_to_actor_error; use crate::python_executor::python_executor; /// Special message type identifier for pickle-encoded Python objects const SEALED_PY_MSG_TYPE: &str = "__sealed_py_message__"; -/// Convert error to Python exception -/// Prefer using pulsing_error_to_py_err_direct for PulsingError types -fn to_pyerr(err: E) -> PyErr { - // Try to downcast to PulsingError - let err_str = err.to_string(); - - // For non-PulsingError types, use 
RuntimeError - // In practice, most errors from pulsing-actor should be PulsingError - PyRuntimeError::new_err(err_str) +/// Convert PulsingError to Python exception (used for actor system APIs that return Result<_, PulsingError>). +fn to_pyerr(err: pulsing_actor::error::PulsingError) -> PyErr { + pulsing_error_to_py_err(err) } -/// Convert PulsingError to Python exception -fn pulsing_to_pyerr(err: PulsingError) -> PyErr { - pulsing_error_to_py_err_direct(err) +/// Convert non-anyhow errors (parse, validation) to Python ValueError. +fn to_py_value_err(err: E) -> PyErr { + PyValueError::new_err(err.to_string()) } /// Python wrapper for NodeId @@ -253,7 +246,7 @@ impl PyMessage { #[staticmethod] fn from_json(py: Python<'_>, msg_type: String, data: PyObject) -> PyResult { let json_value: serde_json::Value = pythonize::depythonize(&data.into_bound(py))?; - let payload = serde_json::to_vec(&json_value).map_err(to_pyerr)?; + let payload = serde_json::to_vec(&json_value).map_err(to_py_value_err)?; Ok(Self { msg_type, payload: Some(payload), @@ -297,7 +290,8 @@ impl PyMessage { fn to_json(&self, py: Python<'_>) -> PyResult { match &self.payload { Some(data) => { - let value: serde_json::Value = serde_json::from_slice(data).map_err(to_pyerr)?; + let value: serde_json::Value = + serde_json::from_slice(data).map_err(to_py_value_err)?; let pyobj = pythonize::pythonize(py, &value)?; Ok(pyobj.into()) } @@ -506,7 +500,7 @@ impl PyStreamReader { #[pyclass(name = "StreamWriter")] pub struct PyStreamWriter { #[allow(clippy::type_complexity)] - sender: Arc>>>>, + sender: Arc>>>>, } #[pymethods] @@ -549,7 +543,11 @@ impl PyStreamWriter { pyo3_async_runtimes::tokio::future_into_py(py, async move { let mut guard = sender.lock().await; if let Some(tx) = guard.take() { - let _ = tx.send(Err(anyhow::anyhow!(msg))).await; + let _ = tx + .send(Err(pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other(msg), + ))) + .await; } Ok(()) }) @@ -572,7 +570,7 @@ pub struct PyStreamMessage { /// Default message type (used when chunk doesn't specify one) default_msg_type: String, #[allow(clippy::type_complexity)] - receiver: Arc>>>>, + receiver: Arc>>>>, } #[pymethods] @@ -583,7 +581,7 @@ impl PyStreamMessage { #[staticmethod] #[pyo3(signature = (msg_type, buffer_size=32))] fn create(msg_type: String, buffer_size: usize) -> (PyStreamMessage, PyStreamWriter) { - let (tx, rx) = mpsc::channel(buffer_size); + let (tx, rx) = mpsc::channel::>(buffer_size); ( PyStreamMessage { default_msg_type: msg_type, @@ -612,7 +610,10 @@ impl PyStreamMessage { enum PyActorResponse { Single(PyMessage), /// Stream of Messages with default msg_type - StreamChannel(String, mpsc::Receiver>), + StreamChannel( + String, + mpsc::Receiver>, + ), /// Pickled Python object for Python-to-Python communication Sealed(Vec), /// Generator (async or sync) to be iterated @@ -739,7 +740,7 @@ impl PySystemConfig { #[staticmethod] fn with_addr(addr: String) -> PyResult { - let socket_addr: SocketAddr = addr.parse().map_err(to_pyerr)?; + let socket_addr: SocketAddr = addr.parse().map_err(to_py_value_err)?; Ok(Self { inner: SystemConfig::with_addr(socket_addr), }) @@ -747,7 +748,7 @@ impl PySystemConfig { fn with_seeds(&self, seeds: Vec) -> PyResult { let seed_addrs: Result, _> = seeds.iter().map(|s| s.parse()).collect(); - let seed_addrs = seed_addrs.map_err(to_pyerr)?; + let seed_addrs = seed_addrs.map_err(to_py_value_err)?; Ok(Self { inner: self.inner.clone().with_seeds(seed_addrs), }) @@ -861,7 +862,7 @@ impl Actor for 
PythonActorWrapper { }) } - async fn on_start(&mut self, ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_start(&mut self, ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { let handler = Python::with_gil(|py| self.handler.clone_ref(py)); let actor_id = *ctx.id(); let event_loop = Python::with_gil(|py| self.event_loop.clone_ref(py)); @@ -873,7 +874,6 @@ impl Actor for PythonActorWrapper { let py_actor_id = PyActorId { inner: actor_id }; let result = handler.call_method1(py, "on_start", (py_actor_id,))?; - // Check if return value is a coroutine, if so wait for it to complete let asyncio = py.import("asyncio")?; let is_coro = asyncio .call_method1("iscoroutine", (&result,))? @@ -893,11 +893,19 @@ impl Actor for PythonActorWrapper { }) }) .await - .map_err(|e| anyhow::anyhow!("Python executor error: {:?}", e))? - .map_err(|e| anyhow::anyhow!("Python on_start error: {:?}", e)) + .map_err(|e| { + pulsing_actor::error::PulsingError::from(pulsing_actor::error::RuntimeError::Other( + format!("Python executor error: {:?}", e), + )) + })? + .map_err(|e| { + pulsing_actor::error::PulsingError::from(pulsing_actor::error::RuntimeError::Other( + format!("Python on_start error: {:?}", e), + )) + }) } - async fn on_stop(&mut self, _ctx: &mut ActorContext) -> anyhow::Result<()> { + async fn on_stop(&mut self, _ctx: &mut ActorContext) -> pulsing_actor::error::Result<()> { let handler = Python::with_gil(|py| self.handler.clone_ref(py)); let event_loop = Python::with_gil(|py| self.event_loop.clone_ref(py)); @@ -907,7 +915,6 @@ impl Actor for PythonActorWrapper { if handler.getattr(py, "on_stop").is_ok() { let result = handler.call_method0(py, "on_stop")?; - // Check if return value is a coroutine, if so wait for it to complete let asyncio = py.import("asyncio")?; let is_coro = asyncio .call_method1("iscoroutine", (&result,))? @@ -927,11 +934,23 @@ impl Actor for PythonActorWrapper { }) }) .await - .map_err(|e| anyhow::anyhow!("Python executor error: {:?}", e))? - .map_err(|e| anyhow::anyhow!("Python on_stop error: {:?}", e)) + .map_err(|e| { + pulsing_actor::error::PulsingError::from(pulsing_actor::error::RuntimeError::Other( + format!("Python executor error: {:?}", e), + )) + })? 
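+        // The trailing `?` above surfaces executor failures; this second map_err wraps errors raised by the Python on_stop callback itself.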
+ .map_err(|e| { + pulsing_actor::error::PulsingError::from(pulsing_actor::error::RuntimeError::Other( + format!("Python on_stop error: {:?}", e), + )) + }) } - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let (handler, event_loop) = Python::with_gil(|py| (self.handler.clone_ref(py), self.event_loop.clone_ref(py))); @@ -1049,19 +1068,26 @@ impl Actor for PythonActorWrapper { }) }) .await - .map_err(|e| anyhow::anyhow!("Python executor error: {:?}", e))?; + .map_err(|e| { + pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other( + format!("Python executor error: {:?}", e), + ), + ) + })?; // Convert Python exceptions to ActorError let response = match response { Ok(resp) => resp, Err(py_err) => { - // Convert Python exception to ActorError Python::with_gil(|py| { - let actor_err = convert_python_exception_to_actor_error(py, &py_err)?; - // Convert ActorError to PulsingError and then to anyhow::Error - Err(anyhow::Error::from( - pulsing_actor::error::PulsingError::from(actor_err), - )) + let actor_err = + convert_python_exception_to_actor_error(py, &py_err).map_err(|e| { + pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other(e.to_string()), + ) + })?; + Err(pulsing_actor::error::PulsingError::from(actor_err)) }) }?, }; @@ -1074,7 +1100,7 @@ impl Actor for PythonActorWrapper { PyActorResponse::Sealed(data) => Ok(Message::single(SEALED_PY_MSG_TYPE, data)), PyActorResponse::Generator(generator, event_loop, is_async) => { // Create channel for streaming generator values - let (tx, rx) = mpsc::channel(32); + let (tx, rx) = mpsc::channel::>(32); // Spawn background task to iterate generator tokio::spawn(async move { @@ -1106,10 +1132,13 @@ impl Actor for PythonActorWrapper { if e.is_instance_of::(py) { break; } - let _ = tx.blocking_send(Err(anyhow::anyhow!( - "Generator error: {}", - e - ))); + let _ = tx.blocking_send(Err( + pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other( + format!("Generator error: {}", e), + ), + ), + )); break; } } @@ -1131,10 +1160,13 @@ impl Actor for PythonActorWrapper { if e.is_instance_of::(py) { break; } - let _ = tx.blocking_send(Err(anyhow::anyhow!( - "Generator error: {}", - e - ))); + let _ = tx.blocking_send(Err( + pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other( + format!("Generator error: {}", e), + ), + ), + )); break; } } @@ -1173,9 +1205,7 @@ impl PyActorSystem { ) -> PyResult> { let config_inner = config.inner; pyo3_async_runtimes::tokio::future_into_py(py, async move { - let system = ActorSystem::new(config_inner) - .await - .map_err(|e| pulsing_to_pyerr(PulsingError::from(e)))?; + let system = ActorSystem::new(config_inner).await.map_err(to_pyerr)?; Ok(PyActorSystem { inner: system, event_loop, @@ -1344,9 +1374,9 @@ impl PyActorSystem { // Parse the path - use new_system for system/* paths (internal use only) let path = if name.starts_with("system/") { - ActorPath::new_system(&name).map_err(to_pyerr)? + ActorPath::new_system(&name).map_err(to_py_value_err)? } else { - ActorPath::new(&name).map_err(to_pyerr)? + ActorPath::new(&name).map_err(to_py_value_err)? 
}; if matches!(policy, RestartPolicy::Never) { @@ -1363,13 +1393,20 @@ impl PyActorSystem { } else { // actor is a factory - named actor with supervision let factory = move || { - Python::with_gil(|py| -> anyhow::Result { - let event_loop = event_loop.clone_ref(py); - let instance = actor.call0(py).map_err(|e| { - anyhow::anyhow!("Python factory error: {:?}", e) - })?; - Ok(PythonActorWrapper::new(instance, event_loop)) - }) + Python::with_gil( + |py| -> pulsing_actor::error::Result { + let event_loop = event_loop.clone_ref(py); + let instance = actor.call0(py).map_err(|e| { + pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other(format!( + "Python factory error: {:?}", + e + )), + ) + })?; + Ok(PythonActorWrapper::new(instance, event_loop)) + }, + ) }; system .spawning() @@ -1444,9 +1481,9 @@ impl PyActorSystem { }; // Use new_system for system/* paths (internal use) let path = if name.starts_with("system/") { - ActorPath::new_system(&name).map_err(to_pyerr)? + ActorPath::new_system(&name).map_err(to_py_value_err)? } else { - ActorPath::new(&name).map_err(to_pyerr)? + ActorPath::new(&name).map_err(to_py_value_err)? }; let instances = system.get_named_instances_detailed(&path).await; let result: Vec> = instances @@ -1581,9 +1618,9 @@ impl PyActorSystem { }; // Use new_system for system/* paths (internal use) let path = if name.starts_with("system/") { - ActorPath::new_system(&name).map_err(to_pyerr)? + ActorPath::new_system(&name).map_err(to_py_value_err)? } else { - ActorPath::new(&name).map_err(to_pyerr)? + ActorPath::new(&name).map_err(to_py_value_err)? }; let node = node_id.map(NodeId::new); let actor_ref = system @@ -1639,7 +1676,7 @@ impl PyActorSystem { pyo3_async_runtimes::tokio::future_into_py(py, async move { // Use system/core - the correct system actor path - let path = ActorPath::new_system("system/core").map_err(to_pyerr)?; + let path = ActorPath::new_system("system/core").map_err(to_py_value_err)?; let actor_ref = system .resolve_named(&path, Some(&NodeId::new(node_id))) .await diff --git a/crates/pulsing-py/src/errors.rs b/crates/pulsing-py/src/errors.rs index 680bec7fc..b777cd3f6 100644 --- a/crates/pulsing-py/src/errors.rs +++ b/crates/pulsing-py/src/errors.rs @@ -1,54 +1,56 @@ //! Python exception bindings for Pulsing errors //! -//! This module converts Rust error types to Python exceptions. -//! Due to PyO3 abi3 limitations, we use PyRuntimeError as the base -//! and let Python layer re-raise as appropriate exception types. +//! This module converts Rust error types to Python exceptions using +//! JSON-structured error envelopes instead of string prefixes. +//! +//! The JSON envelope format: +//! ```json +//! { +//! "category": "actor" | "runtime", +//! // For actor errors (category="actor"): +//! "error": { "type": "business", "code": 400, ... }, +//! // For runtime errors (category="runtime"): +//! "kind": "actor_not_found", +//! "message": "Actor not found: my-actor", +//! "actor_name": "my-actor" // optional +//! } +//! ``` -use pulsing_actor::error::{PulsingError, RuntimeError}; +use pulsing_actor::error::PulsingError; use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -/// Convert Rust PulsingError to appropriate Python exception +/// JSON marker prefix for structured error envelopes. +/// Python layer detects this prefix and parses the JSON payload. +pub const ERROR_ENVELOPE_PREFIX: &str = "__PULSING_ERROR__:"; + +/// Convert Rust PulsingError to Python exception using JSON envelope. 
/// -/// This function prefixes error messages with error type markers so Python -/// layer can identify and re-raise as appropriate exception types. +/// Instead of string-prefix-based encoding (fragile, requires regex parsing), +/// this uses a JSON-structured envelope that Python can reliably decode. pub fn pulsing_error_to_py_err(err: PulsingError) -> PyErr { - let err_msg = err.to_string(); - - match &err { - // Actor errors (user code errors) -> prefix with "ACTOR_ERROR:" - PulsingError::Actor(_actor_err) => { - PyRuntimeError::new_err(format!("ACTOR_ERROR:{}", err_msg)) + let json_str = match &err { + PulsingError::Actor(actor_err) => { + // ActorError already derives Serialize with serde(tag = "type") + let actor_json = serde_json::to_value(actor_err).unwrap_or_else(|_| { + serde_json::json!({"type": "system", "error": err.to_string(), "recoverable": true}) + }); + serde_json::json!({ + "category": "actor", + "error": actor_json, + }) + .to_string() } - // Runtime errors (framework errors) -> prefix with "RUNTIME_ERROR:" - PulsingError::Runtime(runtime_err) => { - // Extract actor name if available for runtime errors - let actor_name = match runtime_err { - RuntimeError::ActorNotFound { name } => Some(name.clone()), - RuntimeError::ActorAlreadyExists { name } => Some(name.clone()), - RuntimeError::ActorNotLocal { name } => Some(name.clone()), - RuntimeError::ActorStopped { name } => Some(name.clone()), - RuntimeError::ActorMailboxFull { name } => Some(name.clone()), - RuntimeError::InvalidActorPath { path: _ } => None, - RuntimeError::MessageTypeMismatch { .. } => None, - RuntimeError::ActorSpawnFailed { .. } => None, - _ => None, - }; - - let full_msg = if let Some(ref name) = actor_name { - format!("RUNTIME_ERROR:{}:actor={}", err_msg, name) - } else { - format!("RUNTIME_ERROR:{}", err_msg) - }; - - PyRuntimeError::new_err(full_msg) - } - } -} + PulsingError::Runtime(runtime_err) => serde_json::json!({ + "category": "runtime", + "kind": runtime_err.kind(), + "message": runtime_err.to_string(), + "actor_name": runtime_err.actor_name(), + }) + .to_string(), + }; -/// Convert PulsingError to Python exception (preferred method) -pub fn pulsing_error_to_py_err_direct(err: PulsingError) -> PyErr { - pulsing_error_to_py_err(err) + PyRuntimeError::new_err(format!("{}{}", ERROR_ENVELOPE_PREFIX, json_str)) } /// Add error classes to Python module @@ -60,3 +62,52 @@ pub fn add_to_module(_m: &Bound<'_, PyModule>) -> PyResult<()> { // Error classes are defined in Python layer Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + use pulsing_actor::error::{ActorError, RuntimeError}; + + #[test] + fn test_actor_error_envelope() { + let err = PulsingError::Actor(ActorError::business(400, "Invalid input", None)); + let py_err = pulsing_error_to_py_err(err); + let msg = py_err.to_string(); + assert!(msg.starts_with(ERROR_ENVELOPE_PREFIX)); + + let json_str = &msg[ERROR_ENVELOPE_PREFIX.len()..]; + let envelope: serde_json::Value = serde_json::from_str(json_str).unwrap(); + assert_eq!(envelope["category"], "actor"); + assert_eq!(envelope["error"]["type"], "business"); + assert_eq!(envelope["error"]["code"], 400); + } + + #[test] + fn test_runtime_error_envelope() { + let err = PulsingError::Runtime(RuntimeError::actor_not_found("my-actor")); + let py_err = pulsing_error_to_py_err(err); + let msg = py_err.to_string(); + assert!(msg.starts_with(ERROR_ENVELOPE_PREFIX)); + + let json_str = &msg[ERROR_ENVELOPE_PREFIX.len()..]; + let envelope: serde_json::Value = serde_json::from_str(json_str).unwrap(); 
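+        // Runtime envelopes carry kind/message/actor_name so the Python layer can map them to specific exception types.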
+ assert_eq!(envelope["category"], "runtime"); + assert_eq!(envelope["kind"], "actor_not_found"); + assert_eq!(envelope["actor_name"], "my-actor"); + } + + #[test] + fn test_anyhow_error_conversion() { + let anyhow_err = anyhow::anyhow!("something went wrong"); + let py_err = pulsing_error_to_py_err(PulsingError::from(RuntimeError::Other( + anyhow_err.to_string(), + ))); + let msg = py_err.to_string(); + assert!(msg.starts_with(ERROR_ENVELOPE_PREFIX)); + + let json_str = &msg[ERROR_ENVELOPE_PREFIX.len()..]; + let envelope: serde_json::Value = serde_json::from_str(json_str).unwrap(); + assert_eq!(envelope["category"], "runtime"); + assert_eq!(envelope["kind"], "other"); + } +} diff --git a/docs/design/name-only-resolve.md b/docs/design/name-only-resolve.md new file mode 100644 index 000000000..450f03dc2 --- /dev/null +++ b/docs/design/name-only-resolve.md @@ -0,0 +1,108 @@ +# 按名字解析 Actor、不依赖具体类型(Name-Only Resolve) + +**已实现**:`resolve(name)` 返回带 `.as_any()` 的包装;`ref.as_any()` 或 `as_any(ref)` 得到可转发任意方法调用的 proxy。 + +## 问题 + +当前要「按名字拿到可调用的 actor」有两种方式: + +1. **类型化解析**:`await SomeClass.resolve("channel.discord")` + - 得到带类型的 `ActorProxy`,可 `await proxy.send_text(...)` + - 调用方必须知道并 import 具体类(如 `DiscordChannel`) + +2. **底层解析**:`await pul.resolve("channel.discord")` + - 得到 `ActorRef`,只能 `ref.ask(msg)`,不能 `proxy.method(...)` + - 调用方要自己拼装 `__call__` / 协议格式 + +像 nanobot 这种「按 channel 名字发消息」的场景,调用方只知道名字(如 `"discord"`),不想依赖具体 channel 类型,因此只能自己维护 `name -> 类` 的映射再做 `XxxChannel.resolve(name)`。希望框架能提供「按名字解析 + 直接调方法」的能力。 + +--- + +## 现有实现要点 + +- `pul.resolve(name)` 已存在,返回 `ActorRef`。 +- `ActorProxy(actor_ref, method_names, async_methods)`: + - `method_names is None` 时,`__getattr__` 不校验名字,任意属性都会返回 `_MethodCaller`(即**已支持「任意方法名」**)。 + - 调用通过现有协议发到 actor(`__call__` / method + args/kwargs),actor 端用 `getattr(instance, method)` 分发,无需改动。 +- 区别只在于:**谁构造 Proxy**(带不带类型信息)、以及**是否知道哪些方法是 async**(影响是否走 streaming 路径)。 + +因此「按名字解析并支持动态方法调用」不需要改协议或 actor 实现,只需要在**解析 + 构造 Proxy** 这一层提供新 API,并在「未知类型」时约定 async 语义。 + +--- + +## 方案 + +### 方案 A:`get_actor(name)` → 动态 Proxy(推荐) + +- **API**:`proxy = await pul.get_actor("channel.discord")`,返回一个「无类型」的 `ActorProxy`。 +- **实现**: + - `get_actor(name)` 内部:`ref = await resolve(name)`,然后 `return ActorProxy(ref, method_names=None, async_methods=?)`。 + - 当 `method_names is None` 时,现有逻辑已允许任意方法名,无需改 `__getattr__`。 +- **async 语义**(二选一): + - **A1**:`async_methods=None` 时**全部按 async 处理**(即 `__getattr__` 里对 `async_methods is None` 时令 `is_async=True`)。 + 这样 `await proxy.send_text(...)` 和 `async for x in proxy.generate(...)` 都能用,适用面最大。 + - **A2**:全部按 sync 处理(当前 `async_methods=set()` 的行为)。 + `await proxy.send_text(...)` 仍可用(因为 `_sync_call` 本身是 async),但**流式返回**(async generator)可能只拿到最终结果或行为未定义,需文档说明「流式请用类型化 resolve」。 +- **推荐 A1**:实现简单(在 Proxy 里把 `async_methods is None` 视为「全部 async」),且与「未知类型时尽量不限制能力」一致。 + +**优点**:不增加新类型、不碰协议;调用方只需 `await pul.get_actor(name)` 然后 `await proxy.xxx(...)`,nanobot 可删掉 `get_channel_actor` 和 name→Class 映射。 +**缺点**:无类型提示、无静态校验;流式在 A2 下需用类型化 resolve。 + +--- + +### 方案 B:`resolve(name)` 返回「默认动态 Proxy」 + +- 保持 `resolve(name)` 返回 `ActorRef` 的语义,**新增**一个重载或单独函数,例如 `get_proxy(name)` / `actor(name)`,行为同方案 A 的 `get_actor(name)`。 +- 或者:为 `ActorRef` 增加 `.proxy(dynamic=True)`,例如 `ref = await resolve(name); proxy = ref.proxy(dynamic=True)`,内部用 `ActorProxy(ref, None, None)` 并约定 async 语义(同 A1/A2)。 + +**优点**:与现有 `resolve` 返回 `ActorRef` 的语义不冲突;需要底层 ref 时仍用 `resolve`。 +**缺点**:多一个 API 或概念(dynamic proxy),需要文档说明与 `SomeClass.resolve` 的差异。 + +--- + +### 方案 C:通过元数据 / Describe 协议补全类型信息 + +- 在 spawn 时把 
actor 的「公共方法名 + 是否 async」存到某处(进程内 registry 或通过某 Describe 协议向 actor 查询)。 +- `get_actor(name)` 时先 `resolve(name)` 得到 ref,再查元数据得到 `(method_names, async_methods)`,用现有 `ActorProxy(ref, method_names, async_methods)` 构造**类型信息完整**的 Proxy。 + +**优点**:能区分 sync/async、流式,且不「全部当 async」;理论上可做更好提示。 +**缺点**: +- 元数据:若只存在 spawn 端,则跨进程 resolve 时拿不到;若要从 actor 查,需要约定 Describe 协议和实现。 +- 实现和运维成本高,对 nanobot 这种「只按名字调几个方法」的场景收益有限。 + +--- + +## 建议 + +- **短期**:采用 **方案 A(`get_actor(name)` + 动态 Proxy)**,并在 Proxy 内采用 **A1**(`async_methods is None` 时全部按 async),这样: + - 只在一个地方(如 `pulsing.actor`)增加 `get_actor(name)`(及可选 `node_id`)。 + - 对 `ActorProxy` 做最小改动:在 `__getattr__` 里当 `self._async_methods is None` 时令 `is_async=True`。 +- **命名**:`get_actor(name)` 与现有 `resolve(name)`(返回 ref)区分清晰;若希望更短,可再提供 `pul.actor(name)` 作为别名。 +- **文档**:说明「无类型、无补全;流式优先用类型化 resolve」即可。 +- 方案 C 可作为后续「可观测性 / 接口发现」的一部分再考虑,不必绑在「按名字解析」的第一版里。 + +--- + +## 使用示例(已实现) + +```python +# 按名字解析后通过 .as_any() 获得「任意方法转发」的 proxy +import pulsing as pul + +ref = await pul.resolve("channel.discord") +proxy = ref.as_any() +await proxy.send_text(chat_id, content) +``` + +```python +# 或使用独立函数(适用于已有 ActorRef 的场景) +proxy = pul.as_any(ref) +await proxy.send_text(chat_id, content) +``` + +```python +# 类型化 proxy 也可 .as_any() 得到无类型视图 +typed = await SomeClass.resolve("my_actor") +any_proxy = typed.as_any() +await any_proxy.any_method(...) +``` diff --git a/examples/rust/actor_benchmark.rs b/examples/rust/actor_benchmark.rs index 4fea553cb..1bfe4aa05 100644 --- a/examples/rust/actor_benchmark.rs +++ b/examples/rust/actor_benchmark.rs @@ -17,7 +17,7 @@ use pulsing_bench::{run_benchmark, BenchmarkArgs}; #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { // Parse command line arguments (simplified) let args: Vec = std::env::args().collect(); diff --git a/examples/rust/behavior_counter.rs b/examples/rust/behavior_counter.rs index 904f6abe7..fc60568da 100644 --- a/examples/rust/behavior_counter.rs +++ b/examples/rust/behavior_counter.rs @@ -14,7 +14,7 @@ fn counter(init: i32) -> Behavior { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { let system = ActorSystem::builder().build().await?; // Behavior implements IntoActor, can be passed directly to spawn_named diff --git a/examples/rust/behavior_fsm.rs b/examples/rust/behavior_fsm.rs index ddaacda9c..d67729a59 100644 --- a/examples/rust/behavior_fsm.rs +++ b/examples/rust/behavior_fsm.rs @@ -94,7 +94,7 @@ fn yellow(stats: Stats) -> Behavior { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { println!("=== Traffic Light State Machine ===\n"); let system = ActorSystem::builder().build().await?; diff --git a/examples/rust/cluster.rs b/examples/rust/cluster.rs index b0f98a340..79532a81a 100644 --- a/examples/rust/cluster.rs +++ b/examples/rust/cluster.rs @@ -16,7 +16,11 @@ struct Counter { #[async_trait] impl Actor for Counter { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let n: i32 = msg.unpack()?; self.count += n; println!("[{}] +{} -> {}", self.node_id, n, self.count); @@ -25,7 +29,7 @@ impl Actor for Counter { } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { tracing_subscriber::fmt().with_env_filter("info").init(); let args: Vec = 
std::env::args().collect(); diff --git a/examples/rust/message_patterns.rs b/examples/rust/message_patterns.rs index 7db7c368c..75d8a7c16 100644 --- a/examples/rust/message_patterns.rs +++ b/examples/rust/message_patterns.rs @@ -14,7 +14,11 @@ struct Demo; #[async_trait] impl Actor for Demo { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { match msg.msg_type() { // Pattern 1: RPC - String in, String out "echo" => { @@ -38,7 +42,9 @@ impl Actor for Demo { // Pattern 3: Client Streaming - sum stream of i32 "sum" => { let Message::Stream { mut stream, .. } = msg else { - return Err(anyhow::anyhow!("Expected stream")); + return Err(pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other("Expected stream".into()), + )); }; let mut total = 0i32; while let Some(chunk) = stream.next().await { @@ -48,13 +54,15 @@ impl Actor for Demo { Message::pack(&total) } - _ => Err(anyhow::anyhow!("Unknown: {}", msg.msg_type())), + _ => Err(pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other(format!("Unknown: {}", msg.msg_type())), + )), } } } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { println!("=== Message Patterns ===\n"); let system = ActorSystem::builder().build().await?; @@ -67,9 +75,18 @@ async fn main() -> anyhow::Result<()> { // Pattern 2: Server Streaming println!("--- Server Streaming ---"); - let req = Message::single("count", bincode::serialize(&3i32)?); + let req = Message::single( + "count", + bincode::serialize(&3i32).map_err(|e| { + pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Serialization(e.to_string()), + ) + })?, + ); let Message::Stream { mut stream, .. } = actor.send(req).await? 
else { - return Err(anyhow::anyhow!("Expected stream")); + return Err(pulsing_actor::error::PulsingError::from( + pulsing_actor::error::RuntimeError::Other("Expected stream".into()), + )); }; while let Some(chunk) = stream.next().await { let n: i32 = chunk?.unpack()?; diff --git a/examples/rust/named_actors.rs b/examples/rust/named_actors.rs index 6a354ed6f..a0de1a22a 100644 --- a/examples/rust/named_actors.rs +++ b/examples/rust/named_actors.rs @@ -11,14 +11,18 @@ struct Echo; #[async_trait] impl Actor for Echo { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let s: String = msg.unpack()?; Message::pack(&format!("echo: {}", s)) } } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { let system = ActorSystem::builder().build().await?; // Spawn named actor - name is now the full path diff --git a/examples/rust/ping_pong.rs b/examples/rust/ping_pong.rs index 5e93035e2..b0124ca17 100644 --- a/examples/rust/ping_pong.rs +++ b/examples/rust/ping_pong.rs @@ -8,14 +8,18 @@ struct Echo; #[async_trait] impl Actor for Echo { - async fn receive(&mut self, msg: Message, _ctx: &mut ActorContext) -> anyhow::Result { + async fn receive( + &mut self, + msg: Message, + _ctx: &mut ActorContext, + ) -> pulsing_actor::error::Result { let s: String = msg.unpack()?; Message::pack(&format!("echo: {}", s)) } } #[tokio::main] -async fn main() -> anyhow::Result<()> { +async fn main() -> pulsing_actor::error::Result<()> { let system = ActorSystem::builder().build().await?; let echo = system.spawn_named("test/echo", Echo).await?; diff --git a/python/pulsing/__init__.py b/python/pulsing/__init__.py index f0854a3eb..d6a55eddf 100644 --- a/python/pulsing/__init__.py +++ b/python/pulsing/__init__.py @@ -69,10 +69,12 @@ def incr(self): self.value += 1; return self.value remote, # Resolve function resolve, + as_any, # Types Actor, ActorSystem as _ActorSystem, ActorRef, + ActorRefView, ActorId, ActorProxy, Message, @@ -273,6 +275,7 @@ async def refer(actorid: ActorId | str) -> ActorRef: "spawn", "refer", "resolve", + "as_any", "get_system", "is_initialized", # Decorator @@ -281,6 +284,7 @@ async def refer(actorid: ActorId | str) -> ActorRef: "Actor", "ActorSystem", "ActorRef", + "ActorRefView", "ActorId", "ActorProxy", "Message", diff --git a/python/pulsing/actor/__init__.py b/python/pulsing/actor/__init__.py index 7a6893260..bea6a7dad 100644 --- a/python/pulsing/actor/__init__.py +++ b/python/pulsing/actor/__init__.py @@ -190,9 +190,11 @@ async def tell_with_timeout( PYTHON_ACTOR_SERVICE_NAME, ActorClass, ActorProxy, + ActorRefView, PythonActorService, PythonActorServiceProxy, SystemActorProxy, + as_any, get_metrics, get_node_info, get_python_actor_service, @@ -230,8 +232,10 @@ async def tell_with_timeout( "SystemConfig", "ActorSystem", "ActorRef", + "ActorRefView", "ActorId", "ActorProxy", + "as_any", "SystemActorProxy", # Service (for actor_system function) "PythonActorService", diff --git a/python/pulsing/actor/remote.py b/python/pulsing/actor/remote.py index 5a33b0f7d..4b07a9be9 100644 --- a/python/pulsing/actor/remote.py +++ b/python/pulsing/actor/remote.py @@ -22,6 +22,21 @@ def _get_protocol_version() -> int: return _DEFAULT_PROTOCOL_VERSION +def _consume_task_exception(task: asyncio.Task) -> None: + """Consume exception from background task to avoid 'Task exception was never retrieved'.""" + try: + 
task.result() + except asyncio.CancelledError: + pass + except (RuntimeError, OSError, ConnectionError) as e: + if "closed" in str(e).lower() or "stream" in str(e).lower(): + logging.getLogger(__name__).debug("Stream closed before response: %s", e) + else: + logging.getLogger(__name__).exception("Stream task failed: %s", e) + except Exception: + logging.getLogger(__name__).exception("Stream task failed") + + def _detect_protocol_version(msg: dict) -> int: """Auto-detect protocol version from message. @@ -154,19 +169,22 @@ def _unwrap_response(resp: dict) -> tuple[Any, str | None]: return (resp.get("__result__"), None) +_PULSING_ERROR_PREFIX = "__PULSING_ERROR__:" + + def _convert_rust_error(err: RuntimeError) -> Exception: """Convert Rust-raised RuntimeError to appropriate Pulsing exception. - Rust layer prefixes error messages with markers: - - "ACTOR_ERROR:" -> PulsingActorError (or specific subclasses) - - "RUNTIME_ERROR:" -> PulsingRuntimeError + Rust layer encodes errors as JSON envelopes with prefix "__PULSING_ERROR__:". + The JSON format: + Actor errors: {"category": "actor", "error": {"type": "business", "code": 400, ...}} + Runtime errors: {"category": "runtime", "kind": "actor_not_found", "message": "...", ...} - The error message format for ActorError: - - "ACTOR_ERROR:Business error [code]: message" -> PulsingBusinessError - - "ACTOR_ERROR:System error: message" -> PulsingSystemError - - "ACTOR_ERROR:Timeout: operation 'op' timed out..." -> PulsingTimeoutError - - "ACTOR_ERROR:Unsupported operation: op" -> PulsingUnsupportedError + This replaces the previous regex-based string prefix parsing with + reliable JSON deserialization. """ + import json + from pulsing.exceptions import ( PulsingBusinessError, PulsingSystemError, @@ -176,51 +194,60 @@ def _convert_rust_error(err: RuntimeError) -> Exception: err_msg = str(err) - if err_msg.startswith("ACTOR_ERROR:"): - msg = err_msg.replace("ACTOR_ERROR:", "") + if not err_msg.startswith(_PULSING_ERROR_PREFIX): + # Not a structured Pulsing error, wrap as generic RuntimeError + return PulsingRuntimeError(err_msg) - # Try to identify specific ActorError type from message - if msg.startswith("Business error ["): - # Extract code, message, and details from "Business error [code]: message" - import re + json_str = err_msg[len(_PULSING_ERROR_PREFIX) :] + try: + envelope = json.loads(json_str) + except (json.JSONDecodeError, ValueError): + # JSON parse failed, fall back to generic error + return PulsingRuntimeError(err_msg) - match = re.match(r"Business error \[(\d+)\]: (.+)", msg) - if match: - code = int(match.group(1)) - message = match.group(2) - return PulsingBusinessError(code, message) + category = envelope.get("category") - if msg.startswith("System error: "): - # Extract error message from "System error: message" - error_msg = msg.replace("System error: ", "") - # Default to recoverable=True (we don't have recoverable flag in message) - return PulsingSystemError(error_msg, recoverable=True) + if category == "actor": + actor_err = envelope.get("error", {}) + err_type = actor_err.get("type") - if msg.startswith("Timeout: operation '"): - # Extract operation and duration from "Timeout: operation 'op' timed out after Xms" - import re + if err_type == "business": + code = actor_err.get("code", 0) + message = actor_err.get("message", "Unknown error") + details = actor_err.get("details") + return PulsingBusinessError(code, message, details=details) - match = re.match( - r"Timeout: operation '([^']+)' timed out after (\d+)ms", msg - ) - if 
match: - operation = match.group(1) - duration_ms = int(match.group(2)) - return PulsingTimeoutError(operation, duration_ms) - - if msg.startswith("Unsupported operation: "): - # Extract operation from "Unsupported operation: op" - operation = msg.replace("Unsupported operation: ", "") + if err_type == "system": + error = actor_err.get("error", "Unknown error") + recoverable = actor_err.get("recoverable", True) + return PulsingSystemError(error, recoverable=recoverable) + + if err_type == "timeout": + operation = actor_err.get("operation", "unknown") + duration_ms = actor_err.get("duration_ms", 0) + return PulsingTimeoutError(operation, duration_ms) + + if err_type == "unsupported": + operation = actor_err.get("operation", "unknown") return PulsingUnsupportedError(operation) - # Fallback: generic PulsingActorError - return PulsingActorError(msg) - elif err_msg.startswith("RUNTIME_ERROR:"): - msg = err_msg.replace("RUNTIME_ERROR:", "") - return PulsingRuntimeError(msg) - else: - # Unknown format, wrap as RuntimeError - return PulsingRuntimeError(err_msg) + # Unknown actor error type, generic fallback + return PulsingActorError(str(actor_err)) + + if category == "runtime": + message = envelope.get("message", "Unknown runtime error") + return PulsingRuntimeError(message) + + # Unknown category + return PulsingRuntimeError(err_msg) + + +async def _ask_convert_errors(ref, msg) -> Any: + """Call ref.ask(msg) and convert Rust RuntimeError to Pulsing exceptions.""" + try: + return await ref.ask(msg) + except RuntimeError as e: + raise _convert_rust_error(e) from e logger = logging.getLogger(__name__) @@ -274,6 +301,27 @@ def get_actor_metadata(name: str) -> dict[str, str] | None: return _actor_metadata_registry.get(name) +class ActorRefView: + """Wrapper around ActorRef that adds .as_any() for an untyped proxy. + + Returned by resolve(name). Delegates .ask(), .tell(), and other + ActorRef attributes to the underlying ref. Use .as_any() to get + a proxy that forwards any method call to the remote actor. 
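+
+    Example (illustrative sketch; the actor name "worker" and method "process"
+    are placeholders, assuming such an actor was spawned elsewhere):
+
+        ref = await resolve("worker")
+        result = await ref.ask(msg)    # delegates to the underlying ActorRef
+        proxy = ref.as_any()
+        await proxy.process(item)      # any method name is forwarded
+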
+ """ + + __slots__ = ("_ref",) + + def __init__(self, ref: ActorRef): + self._ref = ref + + def as_any(self) -> "ActorProxy": + """Return an untyped proxy that forwards any method call to the remote actor.""" + return ActorProxy(self._ref, method_names=None, async_methods=None) + + def __getattr__(self, name: str): + return getattr(self._ref, name) + + PYTHON_ACTOR_SERVICE_NAME = "system/python_actor_service" @@ -288,16 +336,22 @@ def __init__( ): self._ref = actor_ref self._method_names = set(method_names) if method_names else None - self._async_methods = async_methods or set() + # None means "any proxy": allow any method, treat all as async (streaming support) + self._async_methods = async_methods def __getattr__(self, name: str): if name.startswith("_"): raise AttributeError(f"Cannot access private attribute: {name}") if self._method_names is not None and name not in self._method_names: raise AttributeError(f"No method '{name}'") - is_async = name in self._async_methods + # When _async_methods is None (any proxy), treat all methods as async + is_async = self._async_methods is None or name in self._async_methods return _MethodCaller(self._ref, name, is_async=is_async) + def as_any(self) -> "ActorProxy": + """Return an untyped proxy that forwards any method call to the remote actor.""" + return ActorProxy(self._ref, method_names=None, async_methods=None) + @property def ref(self) -> ActorRef: """Get underlying ActorRef.""" @@ -337,7 +391,7 @@ async def _sync_call(self, *args, **kwargs) -> Any: else: call_msg = _wrap_call_v1(self._method, args, kwargs, False) - resp = await self._ref.ask(call_msg) + resp = await _ask_convert_errors(self._ref, call_msg) if isinstance(resp, dict): result, error = _unwrap_response(resp) @@ -397,7 +451,7 @@ async def _get_stream(self): call_msg = _wrap_call_v2(self._method, self._args, self._kwargs, True) else: call_msg = _wrap_call_v1(self._method, self._args, self._kwargs, True) - resp = await self._ref.ask(call_msg) + resp = await _ask_convert_errors(self._ref, call_msg) # Response may be PyMessage (streaming) or direct Python object if isinstance(resp, Message): @@ -443,6 +497,11 @@ async def __anext__(self): ) if "__yield__" in item: return item["__yield__"] + # Single-value response (non-streaming): {"__result__": value} + if "__result__" in item: + self._final_result = item.get("__result__") + self._got_result = True + raise StopAsyncIteration return item except StopAsyncIteration: raise @@ -497,10 +556,7 @@ async def __anext__(self): self._got_result = True raise StopAsyncIteration if "__error__" in item: - # Actor execution error - raise PulsingActorError( - item["__error__"], actor_name=str(self._ref.actor_id.id) - ) + raise PulsingActorError(item["__error__"]) if "__yield__" in item: return item["__yield__"] return item @@ -508,6 +564,40 @@ async def __anext__(self): raise +class _DelayedCallProxy: + """Proxy returned by ``self.delayed(sec)`` — any method call becomes a delayed message to self. + + Usage inside a @remote class:: + + task = self.delayed(5.0).some_method(arg1, arg2) + task.cancel() # cancel if needed + + Returns an ``asyncio.Task`` that fires after the delay. 
+ """ + + __slots__ = ("_ref", "_delay_sec") + + def __init__(self, ref: ActorRef, delay_sec: float): + self._ref = ref + self._delay_sec = delay_sec + + def __getattr__(self, name: str): + if name.startswith("_"): + raise AttributeError(name) + + def caller(*args, **kwargs): + msg = _wrap_call_v1(name, args, kwargs, is_async=True) + delay = max(0.0, self._delay_sec) + + async def _send(): + await asyncio.sleep(delay) + await self._ref.tell(msg) + + return asyncio.create_task(_send()) + + return caller + + class _WrappedActor(_ActorBase): """Wraps user class as an Actor""" @@ -534,6 +624,12 @@ def __original_file__(self): except (TypeError, OSError): return None + def _inject_delayed(self, actor_ref: ActorRef) -> None: + """Inject ``self.delayed(sec)`` on the user instance after spawn.""" + self._instance.delayed = lambda delay_sec: _DelayedCallProxy( + actor_ref, delay_sec + ) + def on_start(self, actor_id) -> None: if hasattr(self._instance, "on_start"): self._instance.on_start(actor_id) @@ -627,6 +723,25 @@ async def receive(self, msg) -> Any: return {"__error__": f"Unknown message type: {type(msg)}"} + @staticmethod + async def _safe_stream_write(writer, obj: dict) -> bool: + """Write to stream; return False if stream already closed (e.g. caller cancelled).""" + try: + await writer.write(obj) + return True + except (RuntimeError, OSError, ConnectionError) as e: + if "closed" in str(e).lower() or "stream" in str(e).lower(): + return False + raise + + @staticmethod + async def _safe_stream_close(writer) -> None: + """Close stream; ignore if already closed.""" + try: + await writer.close() + except (RuntimeError, OSError, ConnectionError): + pass + def _handle_generator_result(self, gen) -> StreamMessage: """Handle generator result, return streaming response""" stream_msg, writer = StreamMessage.create("GeneratorStream") @@ -635,17 +750,26 @@ async def execute(): try: if inspect.isasyncgen(gen): async for item in gen: - await writer.write({"__yield__": item}) + if not await self._safe_stream_write( + writer, {"__yield__": item} + ): + return else: for item in gen: - await writer.write({"__yield__": item}) - await writer.write({"__final__": True, "__result__": None}) + if not await self._safe_stream_write( + writer, {"__yield__": item} + ): + return + await self._safe_stream_write( + writer, {"__final__": True, "__result__": None} + ) except Exception as e: - await writer.write({"__error__": str(e)}) + await self._safe_stream_write(writer, {"__error__": str(e)}) finally: - await writer.close() + await self._safe_stream_close(writer) - asyncio.create_task(execute()) + task = asyncio.create_task(execute()) + task.add_done_callback(_consume_task_exception) return stream_msg def _handle_async_method(self, func, args, kwargs) -> StreamMessage: @@ -658,29 +782,39 @@ async def execute(): # Check result type if inspect.isasyncgen(result): - # Async generator async for item in result: - await writer.write({"__yield__": item}) - await writer.write({"__final__": True, "__result__": None}) + if not await self._safe_stream_write( + writer, {"__yield__": item} + ): + return + await self._safe_stream_write( + writer, {"__final__": True, "__result__": None} + ) elif asyncio.iscoroutine(result): - # Regular async function final_result = await result - await writer.write({"__final__": True, "__result__": final_result}) + await self._safe_stream_write( + writer, {"__final__": True, "__result__": final_result} + ) elif inspect.isgenerator(result): - # Synchronous generator for item in result: - await 
writer.write({"__yield__": item}) - await writer.write({"__final__": True, "__result__": None}) + if not await self._safe_stream_write( + writer, {"__yield__": item} + ): + return + await self._safe_stream_write( + writer, {"__final__": True, "__result__": None} + ) else: - # Regular return value - await writer.write({"__final__": True, "__result__": result}) + await self._safe_stream_write( + writer, {"__final__": True, "__result__": result} + ) except Exception as e: - await writer.write({"__error__": str(e)}) + await self._safe_stream_write(writer, {"__error__": str(e)}) finally: - await writer.close() + await self._safe_stream_close(writer) - # Execute in background task, non-blocking actor - asyncio.create_task(execute()) + task = asyncio.create_task(execute()) + task.add_done_callback(_consume_task_exception) return stream_msg @@ -897,10 +1031,13 @@ async def local( actor_name = f"actors/{self._cls.__name__}_{uuid.uuid4().hex[:8]}" if self._restart_policy != "never": + _wrapped_holder: list[_WrappedActor] = [] def factory(): instance = self._cls(*args, **kwargs) - return _WrappedActor(instance) + wrapped = _WrappedActor(instance) + _wrapped_holder.append(wrapped) + return wrapped actor_ref = await system.spawn( factory, @@ -911,10 +1048,13 @@ def factory(): min_backoff=self._min_backoff, max_backoff=self._max_backoff, ) + if _wrapped_holder: + _wrapped_holder[-1]._inject_delayed(actor_ref) else: instance = self._cls(*args, **kwargs) actor = _WrappedActor(instance) actor_ref = await system.spawn(actor, name=actor_name, public=public) + actor._inject_delayed(actor_ref) # Register actor metadata _register_actor_metadata(actor_name, self._cls) @@ -974,7 +1114,8 @@ async def remote( actor_name = f"actors/{self._cls.__name__}_{uuid.uuid4().hex[:8]}" # Send creation request - resp = await service_ref.ask( + resp = await _ask_convert_errors( + service_ref, Message.from_json( "CreateActor", { @@ -989,7 +1130,7 @@ async def remote( "min_backoff": self._min_backoff, "max_backoff": self._max_backoff, }, - ) + ), ) data = resp.to_json() @@ -1144,8 +1285,9 @@ def ref(self) -> ActorRef: async def _ask(self, msg_type: str) -> dict: """Send SystemMessage and return response.""" - resp = await self._ref.ask( - Message.from_json("SystemMessage", {"type": msg_type}) + resp = await _ask_convert_errors( + self._ref, + Message.from_json("SystemMessage", {"type": msg_type}), ) return resp.to_json() @@ -1221,7 +1363,9 @@ async def list_registry(self) -> list[str]: Returns: List of registered class names """ - resp = await self._ref.ask(Message.from_json("ListRegistry", {})) + resp = await _ask_convert_errors( + self._ref, Message.from_json("ListRegistry", {}) + ) data = resp.to_json() return data.get("classes", []) @@ -1256,7 +1400,8 @@ async def create_actor( Raises: RuntimeError: If creation fails """ - resp = await self._ref.ask( + resp = await _ask_convert_errors( + self._ref, Message.from_json( "CreateActor", { @@ -1270,7 +1415,7 @@ async def create_actor( "min_backoff": min_backoff, "max_backoff": max_backoff, }, - ) + ), ) data = resp.to_json() if resp.msg_type == "Error" or data.get("error"): @@ -1339,8 +1484,11 @@ async def resolve( name: str, *, node_id: int | None = None, -) -> ActorRef: - """Resolve a named actor by name, return ActorRef +): + """Resolve a named actor by name. + + Returns an object that supports .ask(), .tell(), and .as_any(). + Use .as_any() to get an untyped proxy that forwards any method call. For typed ActorProxy with method calls, use Counter.resolve(name) instead. 
@@ -1349,35 +1497,51 @@ async def resolve( node_id: Target node ID, searches in cluster if not provided Returns: - ActorRef: Low-level actor reference for ask/tell operations. + ActorRefView: Ref-like object with .as_any() for untyped proxy. Example: from pulsing.actor import init, remote, resolve await init() - @remote - class Counter: - def __init__(self, init=0): self.value = init - def increment(self): self.value += 1; return self.value - - # Create actor - counter = await Counter.spawn(name="my_counter") + # By name only (no type needed) + ref = await resolve("channel.discord") + proxy = ref.as_any() + await proxy.send_text(chat_id, content) - # Method 1: Use typed resolve (recommended) - proxy = await Counter.resolve("my_counter") - result = await proxy.increment() - - # Method 2: Use low-level resolve + ask + # Low-level ask ref = await resolve("my_counter") - result = await ref.ask({"method": "increment", "args": [], "kwargs": {}}) + result = await ref.ask({"__call__": "increment", "args": [], "kwargs": {}}) """ from . import _global_system if _global_system is None: raise RuntimeError("Actor system not initialized. Call 'await init()' first.") - return await _global_system.resolve(name, node_id=node_id) + try: + ref = await _global_system.resolve(name, node_id=node_id) + return ActorRefView(ref) + except RuntimeError as e: + raise _convert_rust_error(e) from e + + +def as_any(ref: ActorRef | ActorRefView) -> ActorProxy: + """Return an untyped proxy that forwards any method call to the remote actor. + + Use when you have an ActorRef (or ref from resolve()) and want to call + methods by name without the typed class. + + Args: + ref: ActorRef from resolve(name), or raw ActorRef from system.resolve_named(). + + Example: + ref = await resolve("channel.discord") + proxy = as_any(ref) # or proxy = ref.as_any() + await proxy.send_text(chat_id, content) + """ + if isinstance(ref, ActorRefView): + return ref.as_any() + return ActorProxy(ref, method_names=None, async_methods=None) RemoteClass = ActorClass diff --git a/python/pulsing/queue/README.md b/python/pulsing/queue/README.md index 02b64007f..a0d6f9f6d 100644 --- a/python/pulsing/queue/README.md +++ b/python/pulsing/queue/README.md @@ -42,7 +42,7 @@ │ bucket_0 │ │ bucket_1 │ │ bucket_2 │ │ │ │ │ │ │ │ - buffer[] │ │ - buffer[] │ │ - buffer[] │ -│ - Lance │ │ - Lance │ │ - Lance │ +│ - backend │ │ - backend │ │ - backend │ │ - Condition │ │ - Condition │ │ - Condition │ └──────────────┘ └──────────────┘ └──────────────┘ Node A Node B Node A @@ -68,7 +68,7 @@ 每个 bucket 一个实例,负责: - 数据缓冲(内存) -- 数据持久化(Lance) +- 数据持久化(由后端实现) - 消费者阻塞/唤醒(asyncio.Condition) ### 3. 
Queue / QueueWriter / QueueReader @@ -132,7 +132,7 @@ get_bucket_ref(system, topic, bucket_id) ┌─────────────────────────────────────────────────────┐ │ 总数据视图 │ ├─────────────────────────┬───────────────────────────┤ -│ 持久化 (Lance) │ 内存缓冲 │ +│ 持久化(若后端支持) │ 内存缓冲 │ │ [0, persisted_count) │ [persisted_count, total) │ └─────────────────────────┴───────────────────────────┘ ↑ @@ -140,7 +140,7 @@ get_bucket_ref(system, topic, bucket_id) ``` - 写入后数据**立即**对消费者可见(在内存缓冲中) -- 达到 `batch_size` 后自动持久化到 Lance +- 达到 `batch_size` 后由后端决定是否持久化 - 调用 `flush()` 可强制持久化 ## 快速开始 @@ -248,17 +248,9 @@ Consumer (rank=0) Consumer (rank=1) └─▶ bucket_2 └─▶ bucket_3 ``` -## 依赖 - -```bash -pip install lance pyarrow -``` - ---- - ## 可插拔存储后端 -队列系统支持可插拔的存储后端,可根据需求选择不同实现。 +队列仅内置 `memory` 后端;持久化等能力通过**插件**以 `register_backend()` 接入,不在 Pulsing 内直接依赖具体实现。 ### 内置后端 @@ -266,38 +258,20 @@ pip install lance pyarrow |------|------|----------| | `memory` | 纯内存,无持久化(默认) | 测试、临时数据 | -### 持久化后端(需安装 persisting) +### 插件后端 -```bash -pip install persisting[lance] -``` - -| 后端 | 说明 | 适用场景 | -|------|------|----------| -| `LanceBackend` | Lance 持久化 | 一般持久化场景 | -| `PersistingBackend` | 增强版(WAL、监控) | 生产环境 | - -### 使用方式 +持久化或其它后端由第三方包提供,通过 `register_backend()` 注册后使用: ```python -# 使用默认内存后端 +# 默认内存后端 writer = await system.queue.write("my_queue") -# 使用 persisting 的 Lance 持久化后端 -from persisting.queue import LanceBackend +# 使用插件提供的后端(示例) +from my_plugin import MyBackend from pulsing.queue import register_backend -register_backend("lance", LanceBackend) -writer = await system.queue.write("my_queue", backend="lance") - -# 使用增强版后端 -from persisting.queue import PersistingBackend -register_backend("persisting", PersistingBackend) -writer = await system.queue.write( - "my_queue", - backend="persisting", - backend_options={"enable_wal": True, "enable_metrics": True} -) +register_backend("my_backend", MyBackend) +writer = await system.queue.write("my_queue", backend="my_backend") ``` ### 自定义后端 @@ -364,11 +338,7 @@ class MyBackend: - 节点故障时,其 bucket 数据可能丢失(内存部分) - 可以考虑副本或 WAL 机制 -4. **Lance Schema 演化** - - 当前不支持 schema 变化 - - 新字段可能导致写入失败 - -5. **性能优化** +4. 
**性能优化** - `get_bucket_ref` 每次都查询 StorageManager - 可以增加客户端缓存,减少 RPC 调用 diff --git a/python/pulsing/queue/__init__.py b/python/pulsing/queue/__init__.py index e955836a1..0a4184b87 100644 --- a/python/pulsing/queue/__init__.py +++ b/python/pulsing/queue/__init__.py @@ -8,7 +8,7 @@ Storage Backends: - "memory": Pure in-memory backend (built-in default) -- Persistent backends require installing the persisting package +- Custom backends: register_backend() or pass class to write_queue() Example: system = await pul.actor_system() diff --git a/python/pulsing/queue/backend.py b/python/pulsing/queue/backend.py index e293227a4..74d88c5f5 100644 --- a/python/pulsing/queue/backend.py +++ b/python/pulsing/queue/backend.py @@ -1,21 +1,21 @@ """Storage Backend Protocol - Pluggable Storage Implementation -Defines StorageBackend protocol, allowing different storage implementations: -- MemoryBackend: Pure in-memory (built-in default) -- Third-party implementations: e.g., LanceBackend, PersistingBackend provided by persisting +Defines StorageBackend protocol for pluggable storage: +- MemoryBackend: Pure in-memory (built-in default, no extra deps) +- Custom backends: register via register_backend() or pass class to write_queue() Usage: - # Use built-in backend + # Built-in memory backend writer = await write_queue(system, "topic", backend="memory") - # Use persistent backend provided by persisting - from persisting.queue import LanceBackend + # Custom backend (e.g. from a plugin package) + from some_plugin import MyBackend from pulsing.queue import register_backend - register_backend("lance", LanceBackend) - writer = await write_queue(system, "topic", backend="lance") + register_backend("my_backend", MyBackend) + writer = await write_queue(system, "topic", backend="my_backend") # Or pass class directly - writer = await write_queue(system, "topic", backend=LanceBackend) + writer = await write_queue(system, "topic", backend=MyBackend) """ from __future__ import annotations @@ -85,9 +85,7 @@ class MemoryBackend: - Supports blocking wait for new data - Lightweight, suitable for testing and temporary data - For persistence capabilities, use backends provided by the persisting package: - - persisting.queue.LanceBackend: Lance persistence - - persisting.queue.PersistingBackend: Enhanced version (WAL, monitoring, etc.) + For persistence, use a plugin that implements StorageBackend (e.g. register_backend). """ def __init__(self, bucket_id: int, **kwargs): @@ -175,18 +173,17 @@ def total_count(self) -> int: "memory": MemoryBackend, } -# Third-party backend registration (e.g., lance provided by persisting) +# Plugin backends registered via register_backend() _REGISTERED_BACKENDS: dict[str, type] = {} def register_backend(name: str, backend_class: type) -> None: - """Register a custom backend + """Register a custom storage backend (e.g. from a plugin package). Example: - from persisting.queue import LanceBackend - register_backend("lance", LanceBackend) - - writer = await write_queue(system, "topic", backend="lance") + from my_plugin import MyBackend + register_backend("my_backend", MyBackend) + writer = await write_queue(system, "topic", backend="my_backend") """ if not isinstance(backend_class, type): raise TypeError(f"backend_class must be a class, got {type(backend_class)}") @@ -215,7 +212,7 @@ def get_backend_class(backend: str | type) -> type: available = list(_BUILTIN_BACKENDS.keys()) + list(_REGISTERED_BACKENDS.keys()) raise ValueError( f"Unknown backend: {backend}. Available: {available}. 
" - f"Use register_backend() to add custom backends, or install 'persisting' for Lance support." + "Use register_backend() to add custom backends." ) diff --git a/python/pulsing/queue/manager.py b/python/pulsing/queue/manager.py index 412d67dcb..bf1fc5894 100644 --- a/python/pulsing/queue/manager.py +++ b/python/pulsing/queue/manager.py @@ -92,7 +92,10 @@ def __init__( self._buckets: dict[tuple[str, int], ActorRef] = {} # Topic brokers managed by this node: {topic_name: ActorRef} self._topics: dict[str, ActorRef] = {} - self._lock = asyncio.Lock() + # Per-resource locks so different buckets/topics can be created in parallel + self._bucket_locks: dict[tuple[str, int], asyncio.Lock] = {} + self._topic_locks: dict[str, asyncio.Lock] = {} + self._locks_meta = asyncio.Lock() # Cached cluster member information self._members: list[dict] = [] @@ -131,32 +134,30 @@ async def _get_or_create_bucket( backend: str | type | None = None, backend_options: dict | None = None, ) -> ActorRef: - """Get or create local BucketStorage Actor""" + """Get or create local BucketStorage Actor. Per-key lock allows parallel creation.""" key = (topic, bucket_id) - if key in self._buckets: return self._buckets[key] - async with self._lock: + async with self._locks_meta: + if key not in self._bucket_locks: + self._bucket_locks[key] = asyncio.Lock() + lock = self._bucket_locks[key] + + async with lock: if key in self._buckets: return self._buckets[key] - - # Create BucketStorage Actor actor_name = f"bucket_{topic}_{bucket_id}" - # Use provided storage_path or default path if storage_path: bucket_storage_path = f"{storage_path}/bucket_{bucket_id}" else: bucket_storage_path = ( f"{self.base_storage_path}/{topic}/bucket_{bucket_id}" ) - try: - # Try to resolve existing self._buckets[key] = await self.system.resolve_named(actor_name) logger.debug(f"Resolved existing bucket: {actor_name}") except Exception: - # Create new using BucketStorage.local() for proper @remote wrapping proxy = await BucketStorage.local( self.system, bucket_id=bucket_id, @@ -169,33 +170,33 @@ async def _get_or_create_bucket( ) self._buckets[key] = proxy.ref logger.info(f"Created bucket: {actor_name} at {bucket_storage_path}") - return self._buckets[key] async def _get_or_create_topic_broker(self, topic_name: str) -> ActorRef: - """Get or create local TopicBroker Actor""" + """Get or create local TopicBroker Actor. 
Per-topic lock allows parallel creation.""" if topic_name in self._topics: return self._topics[topic_name] - async with self._lock: + async with self._locks_meta: + if topic_name not in self._topic_locks: + self._topic_locks[topic_name] = asyncio.Lock() + lock = self._topic_locks[topic_name] + + async with lock: if topic_name in self._topics: return self._topics[topic_name] - actor_name = f"_topic_broker_{topic_name}" try: self._topics[topic_name] = await self.system.resolve_named(actor_name) logger.debug(f"Resolved existing topic broker: {actor_name}") except Exception: - # Lazy import to avoid circular dependency from pulsing.topic.broker import TopicBroker - # Use TopicBroker.local() to create properly wrapped actor proxy = await TopicBroker.local( self.system, topic_name, self.system, name=actor_name, public=True ) self._topics[topic_name] = proxy.ref logger.info(f"Created topic broker: {actor_name}") - return self._topics[topic_name] # ========== Public Remote Methods ========== @@ -323,8 +324,19 @@ async def get_stats(self) -> dict: } -# Lock to prevent concurrent creation of StorageManager -_manager_lock = asyncio.Lock() +# Per-event-loop lock to prevent concurrent creation of StorageManager. +# Lazy init so the lock is bound to the current loop (avoids "bound to a different event loop" in tests). +_manager_lock: asyncio.Lock | None = None +_manager_lock_loop: asyncio.AbstractEventLoop | None = None + + +def _get_manager_lock() -> asyncio.Lock: + global _manager_lock, _manager_lock_loop + loop = asyncio.get_running_loop() + if _manager_lock is None or _manager_lock_loop is not loop: + _manager_lock = asyncio.Lock() + _manager_lock_loop = loop + return _manager_lock async def get_storage_manager(system: ActorSystem) -> "ActorProxy": @@ -343,7 +355,7 @@ async def get_storage_manager(system: ActorSystem) -> "ActorProxy": except Exception: pass - async with _manager_lock: + async with _get_manager_lock(): # Check local node again try: return await StorageManager.resolve( diff --git a/python/pulsing/queue/queue.py b/python/pulsing/queue/queue.py index 8f801bd01..d1756431b 100644 --- a/python/pulsing/queue/queue.py +++ b/python/pulsing/queue/queue.py @@ -33,8 +33,7 @@ class Queue: storage_path: Storage path backend: Storage backend - "memory": Pure in-memory backend (default) - - Persistent backend requires installing persisting package - - Custom class: Class implementing StorageBackend protocol + - Custom: register_backend() or class implementing StorageBackend backend_options: Additional backend parameters """ @@ -58,9 +57,10 @@ def __init__( self.backend = backend self.backend_options = backend_options - # Actor proxies for each bucket + # Actor proxies for each bucket; per-bucket locks allow parallel resolution self._bucket_refs: dict[int, ActorProxy] = {} - self._init_lock = asyncio.Lock() + self._bucket_locks: dict[int, asyncio.Lock] = {} + self._bucket_locks_meta = asyncio.Lock() # Save event loop reference (for sync wrapper) try: @@ -76,22 +76,21 @@ def _hash_partition(self, value: Any) -> int: return hash_value % self.num_buckets async def _ensure_bucket(self, bucket_id: int) -> ActorProxy: - """Ensure Actor for specified bucket is created + """Ensure Actor for specified bucket is created. - Get bucket reference through StorageManager: - 1. Send GetBucket request to local StorageManager - 2. StorageManager uses consistent hashing to determine owner - 3. If this node, create and return; otherwise return redirect - 4. 
Automatically handle redirects to get bucket on correct node + Uses per-bucket lock so different buckets can be resolved in parallel. """ if bucket_id in self._bucket_refs: return self._bucket_refs[bucket_id] - async with self._init_lock: + async with self._bucket_locks_meta: + if bucket_id not in self._bucket_locks: + self._bucket_locks[bucket_id] = asyncio.Lock() + lock = self._bucket_locks[bucket_id] + + async with lock: if bucket_id in self._bucket_refs: return self._bucket_refs[bucket_id] - - # Get bucket reference through StorageManager self._bucket_refs[bucket_id] = await get_bucket_ref( self.system, self.topic, @@ -327,19 +326,17 @@ async def write_queue( storage_path: Storage path backend: Storage backend - "memory": Pure in-memory backend (default) - - Persistent backend requires installing persisting package - - Custom class: Class implementing StorageBackend protocol + - Custom: register_backend() or pass StorageBackend class backend_options: Additional backend parameters Example: - # Use default in-memory backend writer = await write_queue(system, "my_queue") - # Use persisting's Lance backend - from persisting.queue import LanceBackend + # Custom backend from a plugin + from my_plugin import MyBackend from pulsing.queue import register_backend - register_backend("lance", LanceBackend) - writer = await write_queue(system, "my_queue", backend="lance") + register_backend("my_backend", MyBackend) + writer = await write_queue(system, "my_queue", backend="my_backend") """ # Ensure all nodes in cluster have StorageManager from .manager import ensure_storage_managers diff --git a/python/pulsing/queue/storage.py b/python/pulsing/queue/storage.py index 25caf1e75..d3e70f2c9 100644 --- a/python/pulsing/queue/storage.py +++ b/python/pulsing/queue/storage.py @@ -22,9 +22,8 @@ class BucketStorage: storage_path: Storage path batch_size: Batch size backend: Backend name or backend class - - "memory": Pure in-memory backend - - "lance": Lance persistent backend (default) - - Custom class: Class implementing StorageBackend protocol + - "memory": Pure in-memory backend (default) + - Custom name/class: Use register_backend() or pass class backend_options: Additional parameters passed to backend """ @@ -33,7 +32,7 @@ def __init__( bucket_id: int, storage_path: str, batch_size: int = 100, - backend: str | type = "lance", + backend: str | type = "memory", backend_options: dict[str, Any] | None = None, ): self.bucket_id = bucket_id diff --git a/python/pulsing/topic/broker.py b/python/pulsing/topic/broker.py index ffbf41a54..4210e7f32 100644 --- a/python/pulsing/topic/broker.py +++ b/python/pulsing/topic/broker.py @@ -327,6 +327,8 @@ async def _fanout_ask( logger.warning( f"TopicBroker[{self.topic}] wait_any_ack timeout after {timeout}s" ) + failed = len(tasks) + failed_ids = sub_ids.copy() for task in tasks: if not task.done(): task.cancel() diff --git a/python/pulsing/topic/topic.py b/python/pulsing/topic/topic.py index caeea056c..03774c79c 100644 --- a/python/pulsing/topic/topic.py +++ b/python/pulsing/topic/topic.py @@ -253,7 +253,7 @@ async def start(self) -> None: return if not self._callbacks: - logger.warning(f"TopicReader[{self._reader_id}] has no callbacks") + raise ValueError("at least one callback required") # Create subscriber Actor actor_name = f"_topic_sub_{self._topic}_{self._reader_id}" diff --git a/tests/python/test_queue.py b/tests/python/test_queue.py index 4433f4824..318fac679 100644 --- a/tests/python/test_queue.py +++ b/tests/python/test_queue.py @@ -9,7 +9,7 @@ - Distributed 
consumption (rank/world_size) - Stress tests (high concurrency, large data) -Note: Persistence tests (Lance backend) are in persisting package. +Note: Persistence tests live in plugin packages (e.g. persisting). """ import asyncio diff --git a/tests/python/test_queue_backends.py b/tests/python/test_queue_backends.py index 9d9b5ce85..67d20b69c 100644 --- a/tests/python/test_queue_backends.py +++ b/tests/python/test_queue_backends.py @@ -8,7 +8,7 @@ - Integration with write_queue/read_queue APIs - Custom backend implementation -Note: LanceBackend tests are in persisting package. +Note: Persistent backend tests live in plugin packages (e.g. persisting). """ import asyncio diff --git a/tests/python/test_queue_topic_chaos.py b/tests/python/test_queue_topic_chaos.py new file mode 100644 index 000000000..a47e5dbba --- /dev/null +++ b/tests/python/test_queue_topic_chaos.py @@ -0,0 +1,670 @@ +""" +Queue & Topic 混沌测试 + +在随机延迟、高并发、动态加入/退出、随机参数等混沌场景下验证: +- Queue: 数据不丢、不重(按 rank/world_size 分桶)、无死锁 +- Topic: 订阅者动态变化时发布不崩溃、交付语义可区分、慢/失败订阅者被踢或超时 +- 与 StorageManager 共享资源时无阻塞、无竞态 + +运行: pytest tests/python/test_queue_topic_chaos.py -v -s +""" + +from __future__ import annotations + +import asyncio +import random +import shutil +import tempfile +import time + +import pytest + +import pulsing as pul +from pulsing.queue import read_queue, write_queue +from pulsing.topic import PublishMode, read_topic, write_topic + + +# ============================================================================= +# Fixtures & 随机负载工具 +# ============================================================================= + + +@pytest.fixture +async def actor_system(): + system = await pul.actor_system() + yield system + await system.shutdown() + + +@pytest.fixture +def temp_storage_path(): + path = tempfile.mkdtemp(prefix="chaos_queue_") + yield path + shutil.rmtree(path, ignore_errors=True) + + +def _random_sleep(max_ms: int = 20): + """短随机延迟,模拟混沌.""" + return asyncio.sleep(random.uniform(0, max_ms) / 1000.0) + + +def _chaos_sleep( + min_ms: int = 0, max_ms: int = 50, occasional_long_ms: int | None = 120 +): + """随机延迟:常规 min~max_ms,小概率长延迟(模拟抖动)。""" + if occasional_long_ms and random.random() < 0.08: + return asyncio.sleep(random.uniform(max_ms, occasional_long_ms) / 1000.0) + return asyncio.sleep(random.uniform(min_ms, max_ms) / 1000.0) + + +# ============================================================================= +# Queue 混沌 +# ============================================================================= + + +@pytest.mark.asyncio +async def test_queue_chaos_concurrent_producer_consumer( + actor_system, temp_storage_path +): + """混沌:多生产者 + 多消费者(rank/world_size),随机 put/get/延迟,验证不丢不重.""" + random.seed(42) + topic = "chaos_q_concurrent" + num_buckets = random.choice([3, 4, 5, 6]) + world_size = random.choice([2, 3]) + num_producers = random.randint(2, 5) + messages_per_producer = random.randint(30, 70) + total_expected = num_producers * messages_per_producer + + produced_ids: set[str] = set() + produced_lock = asyncio.Lock() + + async def producer(pid: int): + writer = await write_queue( + actor_system, + topic=topic, + bucket_column="id", + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + for i in range(messages_per_producer): + rid = f"p{pid}_r{i}" + await writer.put({"id": rid, "producer": pid, "seq": i}) + async with produced_lock: + produced_ids.add(rid) + await _chaos_sleep(0, 15, 40) + if random.random() < 0.1: + await writer.flush() + await writer.flush() + + async def consumer(rank: int): + reader = await 
read_queue( + actor_system, + topic=topic, + rank=rank, + world_size=world_size, + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + seen: set[str] = set() + deadline = time.monotonic() + 20.0 + get_limit = random.randint(10, 40) + while time.monotonic() < deadline: + records = await reader.get(limit=get_limit, wait=True, timeout=1.2) + for r in records: + rid = r.get("id") + if rid: + assert rid not in seen, f"duplicate consumption: {rid}" + seen.add(rid) + await _chaos_sleep(2, 25, 50) + if len(seen) >= total_expected: + break + return rank, seen + + await asyncio.gather(*[producer(i) for i in range(num_producers)]) + + consumer_tasks = [asyncio.create_task(consumer(r)) for r in range(world_size)] + results = await asyncio.gather(*consumer_tasks) + + consumed_all: set[str] = set() + for _, seen in results: + consumed_all |= seen + + assert ( + len(consumed_all) == total_expected + ), f"expected {total_expected} unique ids, got {len(consumed_all)}; produced={len(produced_ids)}" + assert consumed_all == produced_ids, "consumed set != produced set" + + +@pytest.mark.asyncio +async def test_queue_chaos_many_buckets_parallel_handles( + actor_system, temp_storage_path +): + """混沌:多桶、多 writer 并行写,多 reader 并行读;用单 reader 收齐后校验总数(多 reader 会瓜分数据).""" + random.seed(43) + topic = "chaos_q_many_buckets" + num_buckets = random.randint(4, 12) + num_writers = random.randint(3, 6) + puts_per_writer = random.randint(20, 50) + expected_count = num_writers * puts_per_writer + + async def write_batch(wid: int): + w = await write_queue( + actor_system, + topic=topic, + bucket_column="id", + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + i = 0 + while i < puts_per_writer: + if random.random() < 0.2 and i + 1 < puts_per_writer: + await w.put( + [ + {"id": f"w{wid}_{i}", "v": i}, + {"id": f"w{wid}_{i+1}", "v": i + 1}, + ] + ) + i += 2 + else: + await w.put({"id": f"w{wid}_{i}", "v": i}) + i += 1 + await _chaos_sleep(0, 15, 45) + await w.flush() + return wid + + await asyncio.gather(*[write_batch(w) for w in range(num_writers)]) + + # 单 reader 读全量,避免多 reader 瓜分导致并集不足 expected_count + r = await read_queue( + actor_system, + topic=topic, + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + collected = [] + for _ in range(120): + limit = random.randint(5, 60) + batch = await r.get(limit=limit) + collected.extend(batch) + if len(collected) >= expected_count: + break + await _chaos_sleep(1, 20, None) + + all_ids = {rec.get("id") for rec in collected} + assert ( + len(all_ids) == expected_count + ), f"expected {expected_count} unique ids, got {len(all_ids)}" + + +@pytest.mark.asyncio +async def test_queue_chaos_reader_reset_and_reread(actor_system, temp_storage_path): + """混沌:同一 reader 多次 reset + get,与间歇写入交错,随机 limit/延迟.""" + random.seed(44) + topic = "chaos_q_reset" + num_buckets = random.choice([2, 3, 4]) + writer = await write_queue( + actor_system, + topic=topic, + bucket_column="id", + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + n_init = random.randint(15, 30) + for i in range(n_init): + await writer.put({"id": f"x{i}", "i": i}) + await _chaos_sleep(0, 5, None) + await writer.flush() + + reader = await read_queue( + actor_system, + topic=topic, + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + + limit_a = random.randint(5, 12) + first = await reader.get(limit=limit_a) + first_ids = {r["id"] for r in first} + reader.reset() + again = await reader.get(limit=limit_a) + again_ids = {r["id"] for r in again} + assert first_ids == 
again_ids, "reset then get should see same first batch" + + await writer.put({"id": "new1", "i": 100}) + await _chaos_sleep(2, 15, 30) + more = await reader.get(limit=random.randint(5, 15)) + ids_more = {r["id"] for r in more} + assert "new1" in ids_more or any(r.get("id") == "new1" for r in more) + + +# ============================================================================= +# Topic 混沌 +# ============================================================================= + + +@pytest.mark.asyncio +async def test_topic_chaos_subscribers_join_leave_during_publish(actor_system): + """混沌:发布过程中订阅者动态加入/退出,随机阶段数/每阶段消息数/模式/延迟.""" + random.seed(45) + topic_name = "chaos_t_join_leave" + writer = await write_topic(actor_system, topic_name) + num_phases = random.randint(4, 8) + messages_per_phase = random.randint(6, 18) + modes = [PublishMode.FIRE_AND_FORGET, PublishMode.BEST_EFFORT] + + received_by_id: dict[str, list] = {} + rec_lock = asyncio.Lock() + + async def make_reader(reader_id: str): + reader = await read_topic(actor_system, topic_name, reader_id=reader_id) + rec_list = [] + + async def on_msg(msg): + async with rec_lock: + rec_list.append(msg) + received_by_id[reader_id] = rec_list + + reader.add_callback(on_msg) + await reader.start() + return reader + + readers_alive: list[tuple[str, object]] = [] + + for phase in range(num_phases): + if random.random() < 0.35 and phase > 0 and readers_alive: + _, r = readers_alive.pop() + await r.stop() + rid = f"r{phase}" + reader = await make_reader(rid) + readers_alive.append((rid, reader)) + + for i in range(messages_per_phase): + mode = random.choice(modes) + await writer.publish( + {"phase": phase, "seq": i}, + mode=mode, + ) + await _chaos_sleep(1, 12, 35) + + await asyncio.sleep(0.4) + + for _, reader in readers_alive: + await reader.stop() + + for rid, rec_list in received_by_id.items(): + assert ( + len(rec_list) > 0 + ), f"reader {rid} should have received at least one message" + + +@pytest.mark.asyncio +async def test_topic_chaos_many_publishers_many_subscribers(actor_system): + """混沌:多发布者 + 多订阅者,随机发布模式/条数/延迟,验证每人收到预期条数.""" + random.seed(46) + topic_name = "chaos_t_many" + num_publishers = random.randint(3, 6) + num_subscribers = random.randint(2, 5) + messages_per_pub = random.randint(15, 40) + total_messages = num_publishers * messages_per_pub + + received: list[list] = [[] for _ in range(num_subscribers)] + locks = [asyncio.Lock() for _ in range(num_subscribers)] + readers = [] + + for i in range(num_subscribers): + reader = await read_topic(actor_system, topic_name, reader_id=f"sub_{i}") + + async def make_cb(idx): + async def cb(msg): + async with locks[idx]: + received[idx].append(msg) + + return cb + + reader.add_callback(await make_cb(i)) + await reader.start() + readers.append(reader) + + async def publish_batch(pid: int): + w = await write_topic(actor_system, topic_name, writer_id=f"pub_{pid}") + modes = [PublishMode.FIRE_AND_FORGET, PublishMode.BEST_EFFORT] + for j in range(messages_per_pub): + mode = random.choice(modes) + await w.publish({"pub": pid, "seq": j}, mode=mode) + await _chaos_sleep(0, 12, 30) + + await asyncio.gather(*[publish_batch(p) for p in range(num_publishers)]) + + await asyncio.sleep(0.6) + + for i in range(num_subscribers): + assert ( + len(received[i]) == total_messages + ), f"subscriber {i} expected {total_messages}, got {len(received[i])}" + + for r in readers: + await r.stop() + + +@pytest.mark.asyncio +async def test_topic_chaos_slow_callback_best_effort(actor_system): + """混沌:部分订阅者 callback 
很慢,随机条数/延迟/超时,best_effort 验证不崩溃.""" + random.seed(47) + topic_name = "chaos_t_slow" + writer = await write_topic(actor_system, topic_name) + num_messages = random.randint(12, 25) + + fast_recv = [] + reader_fast = await read_topic(actor_system, topic_name, reader_id="fast") + reader_fast.add_callback(lambda m: fast_recv.append(m)) + await reader_fast.start() + + slow_recv = [] + reader_slow = await read_topic(actor_system, topic_name, reader_id="slow") + slow_delay = random.uniform(0.05, 0.15) + + async def slow_cb(m): + await asyncio.sleep(slow_delay) + slow_recv.append(m) + + reader_slow.add_callback(slow_cb) + await reader_slow.start() + + for i in range(num_messages): + await writer.publish( + {"seq": i}, + mode=PublishMode.BEST_EFFORT, + timeout=random.uniform(1.5, 3.0), + ) + await _chaos_sleep(2, 20, 50) + + await asyncio.sleep(0.5) + + assert ( + len(fast_recv) == num_messages + ), f"fast subscriber should get all {num_messages}, got {len(fast_recv)}" + await reader_fast.stop() + await reader_slow.stop() + + +# ============================================================================= +# 混合:Queue + Topic 同时混沌 +# ============================================================================= + + +@pytest.mark.asyncio +async def test_chaos_mixed_queue_and_topic_same_loop(actor_system, temp_storage_path): + """混沌:同一 loop 内 queue + topic 并发,随机条数/桶数/延迟.""" + random.seed(48) + q_topic = "chaos_mixed_q" + t_topic = "chaos_mixed_t" + num_buckets = random.randint(2, 6) + q_count = random.randint(25, 55) + t_count = random.randint(20, 45) + + q_done = asyncio.Event() + t_done = asyncio.Event() + + async def queue_chaos(): + w = await write_queue( + actor_system, + topic=q_topic, + bucket_column="id", + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + for i in range(q_count): + await w.put({"id": f"q{i}", "i": i}) + await _chaos_sleep(0, 10, 25) + await w.flush() + r = await read_queue( + actor_system, + topic=q_topic, + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + records = await r.get(limit=q_count + 10) + assert len(records) == q_count + q_done.set() + + async def topic_chaos(): + writer = await write_topic(actor_system, t_topic) + recv = [] + reader = await read_topic(actor_system, t_topic, reader_id="mixed_r") + reader.add_callback(lambda m: recv.append(m)) + await reader.start() + modes = [PublishMode.FIRE_AND_FORGET, PublishMode.BEST_EFFORT] + for i in range(t_count): + await writer.publish({"i": i}, mode=random.choice(modes)) + await _chaos_sleep(0, 8, 20) + await asyncio.sleep(0.25) + assert len(recv) == t_count + await reader.stop() + t_done.set() + + await asyncio.gather(queue_chaos(), topic_chaos()) + assert q_done.is_set() and t_done.is_set() + + +@pytest.mark.asyncio +async def test_chaos_rapid_open_close_handles(actor_system, temp_storage_path): + """混沌:快速反复创建/丢弃 queue writer 和 topic reader,随机次数/延迟.""" + random.seed(49) + n_writes = random.randint(6, 12) + n_readers = random.randint(4, 10) + num_buckets = random.choice([2, 3, 4]) + + for _ in range(n_writes): + w = await write_queue( + actor_system, + topic="chaos_rapid_q", + bucket_column="id", + num_buckets=num_buckets, + storage_path=temp_storage_path, + ) + await w.put({"id": f"x{random.randint(0,1000)}", "v": 1}) + await w.flush() + del w + await _chaos_sleep(1, 15, 40) + + for i in range(n_readers): + reader = await read_topic(actor_system, "chaos_rapid_t", reader_id=f"rapid_{i}") + reader.add_callback(lambda m: None) + await reader.start() + await _chaos_sleep(2, 12, None) + await 
+
+
+@pytest.mark.asyncio
+async def test_chaos_rapid_open_close_handles(actor_system, temp_storage_path):
+    """Chaos: rapidly create and drop queue writers and topic readers; random counts/delays."""
+    random.seed(49)
+    n_writes = random.randint(6, 12)
+    n_readers = random.randint(4, 10)
+    num_buckets = random.choice([2, 3, 4])
+
+    for _ in range(n_writes):
+        w = await write_queue(
+            actor_system,
+            topic="chaos_rapid_q",
+            bucket_column="id",
+            num_buckets=num_buckets,
+            storage_path=temp_storage_path,
+        )
+        await w.put({"id": f"x{random.randint(0, 1000)}", "v": 1})
+        await w.flush()
+        del w
+        await _chaos_sleep(1, 15, 40)
+
+    for i in range(n_readers):
+        reader = await read_topic(actor_system, "chaos_rapid_t", reader_id=f"rapid_{i}")
+        reader.add_callback(lambda m: None)
+        await reader.start()
+        await _chaos_sleep(2, 12, None)
+        await reader.stop()
+        del reader
+
+    writer = await write_topic(actor_system, "chaos_rapid_t")
+    result = await writer.publish({"test": True}, mode=PublishMode.FIRE_AND_FORGET)
+    assert result.subscriber_count >= 0
+
+
+# -----------------------------------------------------------------------------
+# New: high-complexity / randomized load storms
+# -----------------------------------------------------------------------------
+
+
+@pytest.mark.asyncio
+async def test_queue_chaos_storm_random_params(actor_system, temp_storage_path):
+    """Chaos storm: fully random queue params; verify nothing is lost or duplicated."""
+    random.seed(100)
+    topic = "chaos_q_storm"
+    num_buckets = random.randint(2, 8)
+    world_size = random.randint(2, 4)
+    num_producers = random.randint(2, 6)
+    messages_per_producer = random.randint(20, 55)
+    total_expected = num_producers * messages_per_producer
+
+    produced_ids: set[str] = set()
+    plock = asyncio.Lock()
+
+    async def producer(pid: int):
+        w = await write_queue(
+            actor_system,
+            topic=topic,
+            bucket_column="id",
+            num_buckets=num_buckets,
+            storage_path=temp_storage_path,
+        )
+        i = 0
+        while i < messages_per_producer:
+            if random.random() < 0.15 and i + 1 < messages_per_producer:
+                batch = [
+                    {"id": f"storm_p{pid}_{i}", "p": pid, "i": i},
+                    {"id": f"storm_p{pid}_{i+1}", "p": pid, "i": i + 1},
+                ]
+                await w.put(batch)
+                async with plock:
+                    produced_ids.add(batch[0]["id"])
+                    produced_ids.add(batch[1]["id"])
+                i += 2
+            else:
+                rid = f"storm_p{pid}_{i}"
+                await w.put({"id": rid, "p": pid, "i": i})
+                async with plock:
+                    produced_ids.add(rid)
+                i += 1
+            await _chaos_sleep(0, 20, 60)
+        await w.flush()
+
+    async def consumer(rank: int):
+        r = await read_queue(
+            actor_system,
+            topic=topic,
+            rank=rank,
+            world_size=world_size,
+            num_buckets=num_buckets,
+            storage_path=temp_storage_path,
+        )
+        seen: set[str] = set()
+        deadline = time.monotonic() + 25.0
+        while time.monotonic() < deadline:
+            limit = random.randint(8, 50)
+            records = await r.get(limit=limit, wait=True, timeout=1.5)
+            for rec in records:
+                rid = rec.get("id")
+                if rid:
+                    assert rid not in seen, f"duplicate: {rid}"
+                    seen.add(rid)
+            await _chaos_sleep(1, 30, 80)
+            if len(seen) >= total_expected:
+                break
+        return seen
+
+    await asyncio.gather(*[producer(i) for i in range(num_producers)])
+
+    results = await asyncio.gather(*[consumer(r) for r in range(world_size)])
+    consumed_all: set[str] = set()
+    for s in results:
+        consumed_all |= s
+
+    assert (
+        consumed_all == produced_ids
+    ), f"storm: produced {len(produced_ids)} vs consumed {len(consumed_all)}"
+
+
+@pytest.mark.asyncio
+async def test_topic_chaos_storm_random_params(actor_system):
+    """Chaos storm: fully random topic params (pubs/subs/counts/modes/delays); verify delivery."""
+    random.seed(101)
+    topic_name = "chaos_t_storm"
+    num_publishers = random.randint(2, 5)
+    num_subscribers = random.randint(2, 5)
+    messages_per_pub = random.randint(18, 45)
+    total_messages = num_publishers * messages_per_pub
+
+    received: list[list] = [[] for _ in range(num_subscribers)]
+    locks = [asyncio.Lock() for _ in range(num_subscribers)]
+    readers = []
+
+    for i in range(num_subscribers):
+        reader = await read_topic(actor_system, topic_name, reader_id=f"storm_sub_{i}")
+
+        def make_cb(idx):
+            async def cb(msg):
+                async with locks[idx]:
+                    received[idx].append(msg)
+
+            return cb
+
+        reader.add_callback(make_cb(i))
+        await reader.start()
+        readers.append(reader)
+
+    async def pub(pid: int):
+        w = await write_topic(actor_system, topic_name, writer_id=f"storm_pub_{pid}")
+        modes = [PublishMode.FIRE_AND_FORGET, PublishMode.BEST_EFFORT]
+        for j in range(messages_per_pub):
+            await w.publish({"pub": pid, "seq": j}, mode=random.choice(modes))
+            await _chaos_sleep(0, 15, 40)
+
+    await asyncio.gather(*[pub(p) for p in range(num_publishers)])
+    await asyncio.sleep(0.7)
+
+    for i in range(num_subscribers):
+        assert (
+            len(received[i]) == total_messages
+        ), f"storm sub {i}: expected {total_messages}, got {len(received[i])}"
+    for r in readers:
+        await r.stop()
+
+
+@pytest.mark.asyncio
+async def test_chaos_storm_multi_queue_multi_topic(actor_system, temp_storage_path):
+    """Chaos storm: multiple queues + topics at once with random load; no deadlocks, consistent data."""
+    random.seed(102)
+    q_topics = ["chaos_storm_q1", "chaos_storm_q2"]
+    t_topics = ["chaos_storm_t1", "chaos_storm_t2"]
+
+    async def run_queue(qtopic: str):
+        nb = random.randint(2, 5)
+        n_msg = random.randint(20, 45)
+        w = await write_queue(
+            actor_system,
+            topic=qtopic,
+            bucket_column="id",
+            num_buckets=nb,
+            storage_path=temp_storage_path,
+        )
+        for i in range(n_msg):
+            await w.put({"id": f"{qtopic}_{i}", "i": i})
+            await _chaos_sleep(0, 12, 35)
+        await w.flush()
+        r = await read_queue(
+            actor_system,
+            topic=qtopic,
+            num_buckets=nb,
+            storage_path=temp_storage_path,
+        )
+        recs = await r.get(limit=n_msg + 20)
+        assert len(recs) == n_msg
+        return len(recs)
+
+    async def run_topic(ttopic: str):
+        n_msg = random.randint(15, 35)
+        recv = []
+        writer = await write_topic(actor_system, ttopic)
+        reader = await read_topic(actor_system, ttopic, reader_id=f"storm_{ttopic}")
+        reader.add_callback(lambda m: recv.append(m))
+        await reader.start()
+        for i in range(n_msg):
+            await writer.publish(
+                {"i": i},
+                mode=random.choice(
+                    [PublishMode.FIRE_AND_FORGET, PublishMode.BEST_EFFORT]
+                ),
+            )
+            await _chaos_sleep(0, 10, 30)
+        await asyncio.sleep(0.2)
+        assert len(recv) == n_msg
+        await reader.stop()
+        return len(recv)
+
+    results = await asyncio.gather(
+        run_queue(q_topics[0]),
+        run_queue(q_topics[1]),
+        run_topic(t_topics[0]),
+        run_topic(t_topics[1]),
+    )
+    assert len(results) == 4
+    assert all(n > 0 for n in results)  # sanity: every workload processed something
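+
+
+# Illustrative sketch only (not wired into the tests above): each chaos test pins
+# its random.seed(...); one way to reproduce a failure with a different seed is an
+# environment override. The variable name PULSING_CHAOS_SEED is hypothetical, not
+# an existing knob of this repo.
+def _chaos_seed(default: int) -> int:
+    import os
+
+    return int(os.environ.get("PULSING_CHAOS_SEED", default))
+
+
+# Possible usage at the top of a test: random.seed(_chaos_seed(100))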
diff --git a/tests/python/test_receive_error_behavior.py b/tests/python/test_receive_error_behavior.py
new file mode 100644
index 000000000..7415cb5c1
--- /dev/null
+++ b/tests/python/test_receive_error_behavior.py
@@ -0,0 +1,71 @@
+"""
+Tests for receive error behavior (business errors do not kill the actor; panics stop it without recovery).
+
+Covers:
+1. When receive returns/raises an error: the error goes back to the caller, the actor does not exit, and the next message is processed
+2. Repeated receive errors: each error is returned to the caller only; the actor stays alive throughout
+"""
+
+import pytest
+
+import pulsing as pul
+from pulsing.actor import Actor
+
+
+# ============================================================================
+# Fixtures
+# ============================================================================
+
+
+@pytest.fixture
+async def system():
+    """Create a standalone ActorSystem for testing."""
+    sys = await pul.actor_system()
+    yield sys
+    await sys.shutdown()
+
+
+# ============================================================================
+# Actor: returns an error for one specific message, handles the rest normally
+# ============================================================================
+
+
+class ErrorOnBadMessageActor(Actor):
+    """Raises on 'bad'; echoes every other message."""
+
+    async def receive(self, msg):
+        if msg == "bad":
+            raise ValueError("intentional receive error")
+        return msg
+
+
+# ============================================================================
+# Test: a receive error is returned to the caller only; the actor does not exit
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_receive_error_returned_to_caller_actor_stays_alive(system):
+    """On a receive error the caller gets the error, the actor stays alive, and the next message works."""
+    ref = await system.spawn(ErrorOnBadMessageActor(), name="error_on_bad")
+
+    # First message triggers the error; the caller should see an exception
+    with pytest.raises(Exception):
+        await ref.ask("bad")
+
+    # Second message: the actor is still alive and should respond normally
+    result = await ref.ask("ok")
+    assert result == "ok"
+
+
+@pytest.mark.asyncio
+async def test_receive_multiple_errors_then_success(system):
+    """Repeated receive errors: each goes back to the caller; the actor stays alive and the last message succeeds."""
+    ref = await system.spawn(ErrorOnBadMessageActor(), name="multi_error")
+
+    for _ in range(3):
+        with pytest.raises(Exception):
+            await ref.ask("bad")
+
+    result = await ref.ask("ok")
+    assert result == "ok"
diff --git a/tests/python/test_remote_decorator.py b/tests/python/test_remote_decorator.py
index f5be18b43..58100a8ea 100644
--- a/tests/python/test_remote_decorator.py
+++ b/tests/python/test_remote_decorator.py
@@ -8,6 +8,7 @@ This file covers advanced features not in the apis tests:
 - ActorProxy.from_ref with method validation
 - Error handling in methods
+- Delayed call: self.delayed(sec).method(...)
 - Concurrent async method behavior
 """
 
@@ -188,6 +189,87 @@ async def async_method(self):
         await shutdown()
 
 
+# ============================================================================
+# Delayed Call Tests (self.delayed(sec).method(...))
+# ============================================================================
+
+
+@pytest.mark.asyncio
+async def test_remote_delayed_call():
+    """Test self.delayed(sec).method(...) 
schedules a tell after delay.""" + from pulsing.actor import init, shutdown, remote + + @remote + class DelayedCallService: + def __init__(self): + self.received: list[str] = [] + + def trigger_delayed(self): + """Schedule a delayed call to record(); returns immediately.""" + self.delayed(0.05).record("delayed_ok") + return "scheduled" + + def record(self, msg: str): + self.received.append(msg) + + def get_received(self): + return list(self.received) + + await init() + + try: + service = await DelayedCallService.spawn() + + out = await service.trigger_delayed() + assert out == "scheduled" + + # Delayed call not yet delivered + assert await service.get_received() == [] + + await asyncio.sleep(0.1) + + assert await service.get_received() == ["delayed_ok"] + + finally: + await shutdown() + + +@pytest.mark.asyncio +async def test_remote_delayed_call_cancel(): + """Test that the task returned by delayed().method() can be cancelled.""" + from pulsing.actor import init, shutdown, remote + + @remote + class DelayedCancelService: + def __init__(self): + self.received: list[str] = [] + + def schedule_then_cancel(self): + task = self.delayed(1.0).record("should_not_appear") + task.cancel() + return "cancelled" + + def record(self, msg: str): + self.received.append(msg) + + def get_received(self): + return list(self.received) + + await init() + + try: + service = await DelayedCancelService.spawn() + + out = await service.schedule_then_cancel() + assert out == "cancelled" + + await asyncio.sleep(0.2) + assert await service.get_received() == [] + + finally: + await shutdown() + + # ============================================================================ # Async Method Concurrency Tests # ============================================================================ diff --git a/tests/python/test_resolve_as_any.py b/tests/python/test_resolve_as_any.py new file mode 100644 index 000000000..858ae4d3b --- /dev/null +++ b/tests/python/test_resolve_as_any.py @@ -0,0 +1,236 @@ +""" +Tests for resolve().as_any() and as_any(ref): untyped proxy that forwards any method call. + +Covers: +- resolve(name) returns an object with .as_any() +- ref.as_any() returns a proxy; await proxy.method(...) 
works without knowing the actor type +- as_any(ref) function works with ref from resolve() or raw ActorRef +- typed_proxy.as_any() returns an any proxy with the same underlying ref +- ref.ask() / ref.tell() still work (backward compatibility) +""" + +import asyncio + +import pytest + +import pulsing as pul +from pulsing.actor import Actor, ActorRefView, as_any, remote + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +@pytest.fixture +async def initialized_pul(): + """Initialize global pulsing system for testing.""" + await pul.init() + yield + await pul.shutdown() + + +# ============================================================================ +# Test: resolve() returns object with .as_any() +# ============================================================================ + + +@pytest.mark.asyncio +async def test_resolve_returns_ref_view_with_as_any(initialized_pul): + """resolve(name) returns an object that has .as_any() method.""" + await pul.spawn( + _EchoActor(), + name="as_any_echo", + public=True, + ) + + ref = await pul.resolve("as_any_echo") + assert ref is not None + assert hasattr(ref, "as_any") + assert callable(getattr(ref, "as_any")) + + proxy = ref.as_any() + assert proxy is not None + assert hasattr(proxy, "ref") + + +@pytest.mark.asyncio +async def test_resolve_returns_actor_ref_view(initialized_pul): + """resolve(name) returns ActorRefView (or equivalent with .as_any()).""" + await pul.spawn(_EchoActor(), name="ref_view_echo", public=True) + + ref = await pul.resolve("ref_view_echo") + assert isinstance(ref, ActorRefView) + + +# ============================================================================ +# Test: ref.as_any() proxy forwards any method call +# ============================================================================ + + +class _EchoActor(Actor): + """Simple actor that echoes and has a named method for proxy calls.""" + + async def receive(self, msg): + if isinstance(msg, dict) and "echo" in msg: + return msg["echo"] + return msg + + +@pul.remote +class _ServiceWithMethods: + """Remote service with sync and async methods for as_any tests.""" + + def __init__(self): + self.value = 0 + + def get_value(self): + return self.value + + def set_value(self, n: int): + self.value = n + return self.value + + async def async_incr(self): + self.value += 1 + return self.value + + def echo(self, text: str): + return text + + +@pytest.mark.asyncio +async def test_as_any_proxy_calls_sync_method(initialized_pul): + """ref.as_any() returns a proxy; await proxy.sync_method() works.""" + await _ServiceWithMethods.spawn(name="as_any_svc", public=True) + + ref = await pul.resolve("as_any_svc") + proxy = ref.as_any() + + result = await proxy.get_value() + assert result == 0 + + result = await proxy.set_value(42) + assert result == 42 + + result = await proxy.get_value() + assert result == 42 + + +@pytest.mark.asyncio +async def test_as_any_proxy_calls_async_method(initialized_pul): + """await proxy.async_method() works through as_any() proxy.""" + await _ServiceWithMethods.spawn(name="as_any_async_svc", public=True) + + ref = await pul.resolve("as_any_async_svc") + proxy = ref.as_any() + + result = await proxy.async_incr() + assert result == 1 + result = await proxy.async_incr() + assert result == 2 + + +@pytest.mark.asyncio +async def test_as_any_proxy_method_with_args(initialized_pul): + """proxy.method(args, kwargs) forwards correctly.""" + 
await _ServiceWithMethods.spawn(name="as_any_echo_svc", public=True) + + ref = await pul.resolve("as_any_echo_svc") + proxy = ref.as_any() + + result = await proxy.echo("hello") + assert result == "hello" + + +# ============================================================================ +# Test: as_any(ref) function +# ============================================================================ + + +@pytest.mark.asyncio +async def test_as_any_function_with_ref_from_resolve(initialized_pul): + """as_any(ref) works when ref is from pul.resolve().""" + await _ServiceWithMethods.spawn(name="as_any_fn_svc", public=True) + + ref = await pul.resolve("as_any_fn_svc") + proxy = as_any(ref) + + result = await proxy.get_value() + assert result == 0 + + +@pytest.mark.asyncio +async def test_as_any_function_with_raw_ref(initialized_pul): + """as_any(ref) works when ref is raw ActorRef from system.resolve().""" + from pulsing.actor import get_system + + await _ServiceWithMethods.spawn(name="as_any_raw_svc", public=True) + + system = get_system() + raw_ref = await system.resolve("as_any_raw_svc") + proxy = as_any(raw_ref) + + result = await proxy.get_value() + assert result == 0 + + +# ============================================================================ +# Test: typed proxy.as_any() +# ============================================================================ + + +@pytest.mark.asyncio +async def test_typed_proxy_as_any(initialized_pul): + """typed_proxy.as_any() returns a proxy that can call the same methods.""" + await _ServiceWithMethods.spawn(name="typed_any_svc", public=True) + + typed = await _ServiceWithMethods.resolve("typed_any_svc") + result_typed = await typed.get_value() + assert result_typed == 0 + + any_proxy = typed.as_any() + result_any = await any_proxy.get_value() + assert result_any == 0 + + await any_proxy.set_value(100) + assert await typed.get_value() == 100 + + +# ============================================================================ +# Test: backward compatibility — ref.ask() / ref.tell() still work +# ============================================================================ + + +@pytest.mark.asyncio +async def test_resolve_ref_ask_still_works(initialized_pul): + """After resolve(), ref.ask(msg) still works (ActorRefView delegates to _ref).""" + await pul.spawn(_EchoActor(), name="compat_ask_echo", public=True) + + ref = await pul.resolve("compat_ask_echo") + result = await ref.ask({"echo": "hello"}) + assert result == "hello" + + +@pytest.mark.asyncio +async def test_resolve_ref_tell_still_works(initialized_pul): + """After resolve(), ref.tell(msg) still works.""" + + class _CountTell(Actor): + def __init__(self): + self.n = 0 + + async def receive(self, msg): + self.n += 1 + if msg == "get": + return self.n + return None + + await pul.spawn(_CountTell(), name="compat_tell_count", public=True) + + ref = await pul.resolve("compat_tell_count") + await ref.tell(None) + await ref.tell(None) + await asyncio.sleep(0.05) + result = await ref.ask("get") + assert result == 3 diff --git a/tests/python/test_topic.py b/tests/python/test_topic.py index 2871eaf2d..c21ce051b 100644 --- a/tests/python/test_topic.py +++ b/tests/python/test_topic.py @@ -653,19 +653,19 @@ async def producer(prod_id: int): @pytest.mark.asyncio async def test_read_topic_auto_start(actor_system): - """Test auto_start parameter.""" + """Test auto_start parameter: with callbacks, start() runs; without, ValueError.""" _writer = await write_topic(actor_system, "auto_start_topic") - received = [] + # 
auto_start=True with no callbacks must raise + with pytest.raises(ValueError, match="at least one callback required"): + await read_topic(actor_system, "auto_start_topic", auto_start=True) - # No callbacks before start - should warn but work - reader = await read_topic(actor_system, "auto_start_topic", auto_start=True) + # Normal: add_callback then start (or use auto_start=False and start() later) + received = [] + reader = await read_topic(actor_system, "auto_start_topic", auto_start=False) reader.add_callback(lambda m: received.append(m)) - - # Since auto_start=True but no callbacks at creation time, - # the reader started without callbacks + await reader.start() assert reader.is_started - await reader.stop()