From 03e69bb456df12784b5c507bc06d2c0747270f40 Mon Sep 17 00:00:00 2001 From: Viren Baraiya Date: Sat, 18 Apr 2026 14:19:15 -0700 Subject: [PATCH 1/4] support lease extension --- LEASE_EXTENSION.md | 134 +++++++ README.md | 23 ++ WORKER_CONFIGURATION.md | 6 +- docs/WORKER.md | 2 +- docs/design/lease-extension.md | 333 ++++++++++++++++++ examples/lease_extension_example.py | 163 +++++++++ .../client/automator/async_task_runner.py | 94 +++++ src/conductor/client/automator/task_runner.py | 95 +++++ tests/integration/test_lease_extension.py | 234 ++++++++++++ 9 files changed, 1080 insertions(+), 4 deletions(-) create mode 100644 LEASE_EXTENSION.md create mode 100644 docs/design/lease-extension.md create mode 100644 examples/lease_extension_example.py create mode 100644 tests/integration/test_lease_extension.py diff --git a/LEASE_EXTENSION.md b/LEASE_EXTENSION.md new file mode 100644 index 000000000..7949bb441 --- /dev/null +++ b/LEASE_EXTENSION.md @@ -0,0 +1,134 @@ +# Lease Extension (Automatic Heartbeat) + +When a worker picks up a task, the Conductor server starts a `responseTimeoutSeconds` timer. If the worker doesn't send an update before the timer expires, the server marks the task as timed out and re-queues it for retry. + +For long-running tasks (agent tool calls, LLM inference, data processing, batch jobs), the worker is actively executing but the server thinks it's dead. **Lease extension** solves this by automatically sending heartbeats that reset the timeout timer. + +## How It Works + +When `lease_extend_enabled=True`: + +1. Worker picks up a task with `responseTimeoutSeconds > 0` +2. SDK starts tracking the task for heartbeats +3. At **80% of `responseTimeoutSeconds`**, SDK sends a heartbeat (`TaskResult.extend_lease=True`) +4. Server resets the task's `updateTime` to now, giving a fresh `responseTimeoutSeconds` window +5. 
Heartbeats continue until the task completes, fails, or the worker shuts down + +``` +Timeline (responseTimeoutSeconds=120s): + 0s 96s 192s 288s + |-----------|-----------|-----------|--→ task completes + poll heartbeat heartbeat heartbeat + (80%) (80%) (80%) +``` + +The heartbeat fires at 80% of `responseTimeoutSeconds` (matching the Java SDK). This gives a 20% safety margin — if a heartbeat is slightly delayed, the task still has time before the server times it out. + +## Quick Start + +```python +from conductor.client.worker.worker_task import worker_task + +@worker_task( + task_definition_name='long_running_analysis', + lease_extend_enabled=True, # Enable automatic heartbeat +) +def analyze_dataset(dataset_id: str) -> dict: + """This task takes 5 minutes but responseTimeoutSeconds is 60s. + Heartbeats keep it alive automatically.""" + results = run_expensive_analysis(dataset_id) + return {'results': results} +``` + +That's it. The SDK handles heartbeats automatically in the background. + +## Enabling Lease Extension + +Lease extension is **disabled by default** (matching the Java SDK). Enable it per-worker or globally: + +### Per-Worker (Decorator) + +```python +@worker_task( + task_definition_name='my_task', + lease_extend_enabled=True, +) +def my_task(data: str) -> dict: + ... +``` + +### Per-Worker (Class) + +```python +from conductor.client.worker.worker import Worker + +worker = Worker( + task_definition_name='my_task', + execute_function=my_function, + lease_extend_enabled=True, +) +``` + +### Per-Worker (Environment Variable) + +```shell +export conductor_worker_my_task_lease_extend_enabled=true +``` + +### Global (All Workers) + +```shell +export conductor_worker_all_lease_extend_enabled=true +``` + +### Precedence + +Environment variables override decorator/constructor arguments: + +1. Task-specific env var (`conductor_worker_<task_name>_lease_extend_enabled`) +2. Global env var (`conductor_worker_all_lease_extend_enabled`) +3. 
Worker constructor / decorator argument + +## When to Use + +**Enable lease extension when:** +- Task execution time may exceed `responseTimeoutSeconds` +- Tasks involve external calls with unpredictable latency (LLM APIs, data pipelines) +- You want the worker to hold the task continuously (not yield and re-poll) + +**You don't need lease extension when:** +- Tasks always complete within `responseTimeoutSeconds` +- You're using `TaskInProgress` with `callbackAfterSeconds` (the task is yielded back to the queue) +- `responseTimeoutSeconds` is 0 (no timeout configured) + +## Lease Extension vs TaskInProgress + +These are two different strategies for long-running tasks: + +| | Lease Extension | TaskInProgress | +|---|---|---| +| **How it works** | Worker holds the task, heartbeats keep it alive | Worker yields the task, re-polls later | +| **Task state** | IN_PROGRESS the whole time | Returned to queue between polls | +| **When to use** | Continuous execution (LLM calls, streaming) | Incremental processing (batch chunks, polling external status) | +| **Enable with** | `lease_extend_enabled=True` | Return `TaskInProgress(callback_after_seconds=N)` | +| **Worker memory** | Task stays in worker memory | Task is released, re-polled with fresh context | + +You can combine both — enable `lease_extend_enabled` for safety while also using `TaskInProgress` for incremental polling. + +## Important Constraints + +- **`responseTimeoutSeconds`** is the time between updates. This is what heartbeats reset. +- **`timeoutSeconds`** is the overall SLA wall-clock ceiling. **Cannot be extended by heartbeat.** Once exceeded, the task is TIMED_OUT regardless of heartbeats. +- Heartbeats only fire when `responseTimeoutSeconds > 0` and `lease_extend_enabled = True`. +- If the heartbeat interval would be less than 1 second (i.e., `responseTimeoutSeconds < 1.25`), heartbeats are skipped. 
+ +## Retry on Failure + +If a heartbeat API call fails, the SDK makes up to 3 attempts, backing off `1s` and then `1.5s` between them. If all attempts fail, the error is logged and the SDK tries again on the next poll loop iteration. If the network is truly partitioned, the server will eventually time out the task — this is correct behavior. + +## Example + +See [examples/lease_extension_example.py](examples/lease_extension_example.py) for a complete runnable example that: +- Defines a long-running worker with `lease_extend_enabled=True` +- Creates a workflow with a short `responseTimeoutSeconds` +- Runs the workflow and proves the task completes despite sleeping longer than the timeout diff --git a/README.md b/README.md index 4fe864ce5..955352089 100644 --- a/README.md +++ b/README.md @@ -19,6 +19,7 @@ If you find [Conductor](https://github.com/conductor-oss/conductor) useful, plea * [Workers: Sync and Async](#workers-sync-and-async) * [Workflows with HTTP Calls and Waits](#workflows-with-http-calls-and-waits) * [Long-Running Tasks with TaskContext](#long-running-tasks-with-taskcontext) + * [Lease Extension for Long-Running Tasks](#lease-extension-for-long-running-tasks) * [Monitoring with Metrics](#monitoring-with-metrics) * [Managing Workflow Executions](#managing-workflow-executions) * [AI & LLM Workflows](#ai--llm-workflows) @@ -275,6 +276,26 @@ def batch_job(batch_id: str) -> Union[dict, TaskInProgress]: See [examples/task_context_example.py](examples/task_context_example.py) for all patterns (polling, retry-aware logic, async context, input access). +### Lease Extension for Long-Running Tasks + +For tasks that run longer than `responseTimeoutSeconds` (e.g., LLM inference, data pipelines, batch jobs), enable automatic lease extension. 
The SDK sends heartbeats at 80% of `responseTimeoutSeconds`, resetting the server's timeout timer so the task stays alive: + +```python +from conductor.client.worker.worker_task import worker_task + +@worker_task( + task_definition_name='train_model', + lease_extend_enabled=True, # Automatic heartbeat — keeps task alive +) +def train_model(dataset_id: str) -> dict: + """Runs for 10 minutes, but responseTimeoutSeconds is only 60s. + Heartbeats at 48s intervals keep the lease alive.""" + model = train(dataset_id) + return {'model_id': model.id, 'accuracy': model.accuracy} +``` + +Disabled by default. Enable per-worker via decorator, constructor, or environment variable (`conductor_worker_<task_name>_lease_extend_enabled=true`). See [LEASE_EXTENSION.md](LEASE_EXTENSION.md) for the full guide. + ### Monitoring with Metrics Enable Prometheus metrics with a single setting — the SDK exposes poll counts, execution times, error rates, and HTTP latency: @@ -407,6 +428,7 @@ See the [Examples Guide](examples/README.md) for the full catalog. 
Key examples: | [kitchensink.py](examples/kitchensink.py) | All task types (HTTP, JS, JQ, Switch) | `python examples/kitchensink.py` | | [workflow_ops.py](examples/workflow_ops.py) | Pause, resume, terminate, retry, restart, rerun, signal | `python examples/workflow_ops.py` | | [task_context_example.py](examples/task_context_example.py) | Long-running tasks with TaskInProgress | `python examples/task_context_example.py` | +| [lease_extension_example.py](examples/lease_extension_example.py) | Automatic heartbeat for long-running tasks | `python examples/lease_extension_example.py` | | [metrics_example.py](examples/metrics_example.py) | Prometheus metrics collection | `python examples/metrics_example.py` | | [fastapi_worker_service.py](examples/fastapi_worker_service.py) | FastAPI: expose a workflow as an API (+ workers) | `uvicorn examples.fastapi_worker_service:app --port 8081 --workers 1` | | [helloworld.py](examples/helloworld/helloworld.py) | Minimal hello world | `python examples/helloworld/helloworld.py` | @@ -431,6 +453,7 @@ End-to-end examples covering all APIs for each domain: | [Worker Design](docs/design/WORKER_DESIGN.md) | Architecture: AsyncTaskRunner vs TaskRunner, discovery, lifecycle | | [Worker Guide](docs/WORKER.md) | All worker patterns (function, class, annotation, async) | | [Worker Configuration](WORKER_CONFIGURATION.md) | Hierarchical environment variable configuration | +| [Lease Extension](LEASE_EXTENSION.md) | Automatic heartbeat for long-running tasks | | [Workflow Management](docs/WORKFLOW.md) | Start, pause, resume, terminate, retry, search | | [Workflow Testing](docs/WORKFLOW_TESTING.md) | Unit testing with mock outputs | | [Task Management](docs/TASK_MANAGEMENT.md) | Task operations | diff --git a/WORKER_CONFIGURATION.md b/WORKER_CONFIGURATION.md index d52590f1d..9d8ab1a01 100644 --- a/WORKER_CONFIGURATION.md +++ b/WORKER_CONFIGURATION.md @@ -29,12 +29,12 @@ The following properties can be configured via environment variables: | 
`overwrite_task_def` | bool | Overwrite existing task definitions when registering (default: true) | `false` | ✅ Yes | | `strict_schema` | bool | Enforce strict schema validation - additionalProperties=false (default: false) | `true` | ✅ Yes | | `poll_timeout` | int | Poll request timeout in milliseconds | `100` | ✅ Yes | -| `lease_extend_enabled` | bool | ⚠️ **Not implemented** - reserved for future use | `false` | ✅ Yes | +| `lease_extend_enabled` | bool | Auto-extend task lease via heartbeat (see below) | `false` | ✅ Yes | | `paused` | bool | Pause worker from polling/executing tasks | `true` | ❌ **Environment-only** | **Notes**: - The `paused` property is intentionally **not available** in the `@worker_task` decorator. It can only be controlled via environment variables, allowing operators to pause/resume workers at runtime without code changes or redeployment. -- The `lease_extend_enabled` parameter is accepted but **not currently implemented**. For lease extension, use manual `TaskInProgress` returns (see below). +- When `lease_extend_enabled=True`, the SDK automatically sends heartbeats at 80% of `responseTimeoutSeconds` to keep long-running tasks alive. Without it, tasks that exceed `responseTimeoutSeconds` are timed out and retried by the server. - The `register_task_def` parameter automatically registers task definitions with JSON Schema (draft-07) generated from Python type hints. - The `overwrite_task_def` parameter controls whether to overwrite existing task definitions (default: true). - The `strict_schema` parameter controls JSON schema validation strictness (default: false for lenient validation). @@ -97,7 +97,7 @@ def long_task(job_id: str) -> Union[dict, TaskInProgress]: return {'status': 'completed', 'result': processed} ``` -**⚠️ Note**: The `lease_extend_enabled=True` configuration parameter does **not** provide automatic lease extension. You must explicitly return `TaskInProgress` to extend the lease. 
+**Automatic lease extension**: When `lease_extend_enabled=True`, the SDK sends a heartbeat to the server at 80% of `responseTimeoutSeconds`, resetting the timeout clock. This keeps the task alive without requiring manual `TaskInProgress` returns. The `TaskInProgress` pattern is still useful for chunked/checkpoint-based execution where the worker yields the task back to the queue. **For detailed patterns**, see [Long-Running Tasks & Lease Extension](docs/design/WORKER_DESIGN.md#long-running-tasks--lease-extension). diff --git a/docs/WORKER.md b/docs/WORKER.md index 372d3e480..848e9ccf3 100644 --- a/docs/WORKER.md +++ b/docs/WORKER.md @@ -593,7 +593,7 @@ class SimpleCppWorker(WorkerInterface): ## Long-Running Tasks and Lease Extension -For tasks that take longer than the configured `responseTimeoutSeconds`, the SDK provides automatic lease extension to prevent timeouts. See the comprehensive [Lease Extension Guide](../../LEASE_EXTENSION.md) for: +For tasks that take longer than the configured `responseTimeoutSeconds`, the SDK provides automatic lease extension to prevent timeouts. See the comprehensive [Lease Extension Guide](../LEASE_EXTENSION.md) for: - How lease extension works - Automatic vs manual control diff --git a/docs/design/lease-extension.md b/docs/design/lease-extension.md new file mode 100644 index 000000000..8ca6daec7 --- /dev/null +++ b/docs/design/lease-extension.md @@ -0,0 +1,333 @@ +# Design: Task Lease Extension (Heartbeat) + +## Problem + +When a worker picks up a task, the Conductor server starts a `responseTimeoutSeconds` timer. If the worker doesn't send an update before the timer expires, the server marks the task as timed out and re-queues it for retry. + +This is a problem for long-running tasks (e.g., agent tool calls, LLM inference, data processing). The worker is actively executing, but the server thinks it's dead. 
+ +Today, the only workaround in the Python SDK is `TaskInProgress` with `callbackAfterSeconds` — which yields the task back to the queue. That doesn't work for continuous execution where the worker must hold the task. + +The Java SDK solves this with automatic lease extension: a background thread periodically sends a heartbeat (`extendLease=true`) to reset the `responseTimeoutSeconds` timer. The Python SDK has the `extend_lease` field on `TaskResult` and `lease_extend_enabled` on Worker, but **no background heartbeat loop** — the docs say "Not implemented — reserved for future use." + +## Goal + +Implement automatic lease extension in the Python SDK, matching the Java SDK's semantics: + +- Heartbeat at 80% of `responseTimeoutSeconds` +- Only when `responseTimeoutSeconds > 0` and `lease_extend_enabled = True` +- Retry on failure (3 attempts) +- Automatic stop on task completion +- Graceful shutdown cleanup +- **Disabled by default** — opt-in per-worker or globally via config/env var + +## Java SDK Reference + +The Java SDK implementation in `TaskRunner.java`: + +```java +// Constants +LEASE_EXTEND_RETRY_COUNT = 3 +LEASE_EXTEND_DURATION_FACTOR = 0.8 + +// Disabled by default — Worker interface: +default boolean leaseExtendEnabled() { + return PropertyFactory.getBoolean(getTaskDefName(), PROP_LEASE_EXTEND_ENABLED, false); +} + +// When task is polled and worker has leaseExtendEnabled: +if (task.getResponseTimeoutSeconds() > 0 && worker.leaseExtendEnabled()) { + long delay = Math.round(task.getResponseTimeoutSeconds() * 0.8); + leaseExtendFuture = leaseExtendExecutorService.scheduleWithFixedDelay( + extendLease(task, taskFuture), delay, delay, TimeUnit.SECONDS + ); + leaseExtendMap.put(task.getTaskId(), leaseExtendFuture); +} + +// Cancellation — in processTask() finally block: +cancelLeaseExtension(task.getTaskId()) +``` + +Key properties: +- **Disabled by default** — must be explicitly enabled per worker +- **Only when `responseTimeoutSeconds > 0`** — no heartbeat 
if there's no timeout +- **Fires at 80%** of `responseTimeoutSeconds` — e.g., 120s timeout → heartbeat at 96s +- **Separate single-threaded executor** — heartbeats never block task execution +- **Retry** — 3 attempts with `500ms * (count+1)` backoff +- **Always cancelled in finally** — whether task succeeds, fails, or throws + +## Server-Side Behavior (No Changes Needed) + +The server already fully supports lease extension. No server changes required. + +### Flow + +1. Worker sends `POST /tasks` with `TaskResult.extendLease = true` +2. `WorkflowExecutorOps.updateTask()` checks `isExtendLease()` → short-circuits: + ```java + if (taskResult.isExtendLease()) { + extendLease(taskResult); // resets updateTime only + return null; // no further task processing + } + ``` +3. `ExecutionDAOFacade.extendLease()` updates **only** `task.updateTime`: + ```java + public void extendLease(TaskModel taskModel) { + taskModel.setUpdateTime(System.currentTimeMillis()); + executionDAO.updateTask(taskModel); + } + ``` +4. All other fields in the `TaskResult` are ignored + +### Timeout Check + +The server's `DeciderService.isResponseTimedOut()` runs during workflow evaluation: + +``` +Task times out when: + (now - task.updateTime) > (responseTimeoutSeconds + callbackAfterSeconds) * 1000 +``` + +Each heartbeat resets `updateTime` to now → fresh `responseTimeoutSeconds` window. + +### Validations + +- Task must exist (404 if not) +- Task must NOT be in terminal state — if already COMPLETED/FAILED/TIMED_OUT, heartbeat is silently ignored (handles race conditions) + +### Response + +- `POST /tasks` → returns task ID (plain text) +- `POST /tasks/update-v2` → returns `204 No Content` (no next-task polling on heartbeat) + +### Important Constraints + +- **`responseTimeoutSeconds`** — time between updates before timeout. **This is what heartbeat resets.** +- **`timeoutSeconds`** — overall SLA wall-clock ceiling. **Cannot be extended by heartbeat.** Once exceeded, task is TIMED_OUT regardless. 
+ +## Design + +### Approach: Inline Heartbeat in the Poll Loop + +Instead of a dedicated background thread (Java approach), piggyback heartbeats on the existing `run_once()` poll loop that already cycles continuously in each worker process. + +``` +TaskRunner.run_once() loop (already exists): + 1. Check completed async tasks + 2. Cleanup finished futures + 3. ← NEW: Check in-flight tasks, send heartbeats if due + 4. Batch poll for new tasks + 5. Submit tasks to thread pool +``` + +**Why this over background threads:** +- No extra threads — uses the existing poll loop infrastructure +- No Timer/ScheduledExecutor complexity, no cancellation logic +- Heartbeat state is naturally cleaned up when tasks complete +- Simpler shutdown — no separate executor to drain +- The poll loop runs frequently enough (adaptive backoff resets when tasks are active) + +**Trade-off:** Heartbeat timing is approximate (depends on poll loop frequency), but it doesn't need to be precise — we just need to fire before 100% of `responseTimeoutSeconds`. + +### Tracking State + +```python +@dataclass +class _LeaseInfo: + """Tracks when a heartbeat is next due for an in-flight task.""" + task_id: str + workflow_instance_id: str + response_timeout_seconds: float + last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) + interval_seconds: float # 80% of responseTimeoutSeconds +``` + +On `TaskRunner.__init__()`: +```python +self._lease_info: dict[str, _LeaseInfo] = {} # task_id → _LeaseInfo +``` + +### Core Methods + +```python +LEASE_EXTEND_RETRY_COUNT = 3 +LEASE_EXTEND_DURATION_FACTOR = 0.8 + +def _track_lease(self, task: Task) -> None: + """Start tracking a task for heartbeat. 
Called when task begins execution.""" + if not self.worker.lease_extend_enabled: + return + timeout = task.response_timeout_seconds + if not timeout or timeout <= 0: + return + interval = timeout * LEASE_EXTEND_DURATION_FACTOR + if interval < 1: + return + self._lease_info[task.task_id] = _LeaseInfo( + task_id=task.task_id, + workflow_instance_id=task.workflow_instance_id, + response_timeout_seconds=timeout, + last_heartbeat_time=time.monotonic(), + interval_seconds=interval, + ) + +def _untrack_lease(self, task_id: str) -> None: + """Stop tracking a task. Called when task completes or fails.""" + self._lease_info.pop(task_id, None) + +def _send_due_heartbeats(self) -> None: + """Check all tracked tasks and send heartbeats for any that are due. + Called at the top of each run_once() iteration.""" + if not self._lease_info: + return + now = time.monotonic() + for info in list(self._lease_info.values()): + elapsed = now - info.last_heartbeat_time + if elapsed < info.interval_seconds: + continue + # Heartbeat is due + self._send_heartbeat(info) + info.last_heartbeat_time = time.monotonic() + +def _send_heartbeat(self, info: _LeaseInfo) -> None: + """Send a single lease extension heartbeat with retry.""" + result = TaskResult( + task_id=info.task_id, + workflow_instance_id=info.workflow_instance_id, + extend_lease=True, + ) + for attempt in range(LEASE_EXTEND_RETRY_COUNT): + try: + self.task_client.update_task(body=result) + logger.debug("Extended lease for task %s", info.task_id) + return + except Exception: + if attempt < LEASE_EXTEND_RETRY_COUNT - 1: + time.sleep(0.5 * (attempt + 2)) + else: + logger.error( + "Failed to extend lease for task %s after %d attempts", + info.task_id, LEASE_EXTEND_RETRY_COUNT, + ) +``` + +### Integration Points + +#### TaskRunner.run_once() + +```python +def run_once(self) -> None: + self.__check_completed_async_tasks() + self.__cleanup_completed_tasks() + self._send_due_heartbeats() # ← NEW: send heartbeats before polling + # ... 
existing polling and task submission logic ... +``` + +#### Task Execution Tracking + +In `__execute_and_update_task()`: + +```python +def __execute_and_update_task(self, task: Task) -> None: + self._track_lease(task) # ← NEW + try: + while task is not None and not self._shutdown: + task_result = self.__execute_task(task) + if task_result is None or isinstance(task_result, TaskInProgress): + return + self._untrack_lease(task.task_id) # ← NEW: done with this task + task = self.__update_task(task_result) + if task is not None: + self._track_lease(task) # ← NEW: v2 returned next task + finally: + if task is not None: + self._untrack_lease(task.task_id) # ← NEW: always cleanup +``` + +#### AsyncTaskRunner.run_once() + +Same pattern — `_send_due_heartbeats_async()` at the top of `run_once()`, using `await` for the API call: + +```python +async def run_once(self) -> None: + await self._send_due_heartbeats_async() # ← NEW + # ... existing async polling and task submission logic ... +``` + +#### Cleanup on Shutdown + +```python +def _cleanup(self) -> None: + self._lease_info.clear() # ← NEW: drop all tracking + # ... existing shutdown logic ... +``` + +### Configuration + +No new configuration surface. The existing plumbing already works end-to-end: + +| Layer | How to enable | Default | +|-------|--------------|---------| +| `@worker_task` decorator | `lease_extend_enabled=True` | `False` | +| `Worker` class | `Worker(..., lease_extend_enabled=True)` | `False` | +| Environment variable | `conductor.worker.<task_name>.lease_extend_enabled=true` | `False` | +| Global env override | `conductor.worker.all.lease_extend_enabled=true` | `False` | +| `TaskRunner` | `__set_worker_properties()` resolves and applies | `False` | + +Matches Java SDK: **disabled by default**, opt-in per worker or globally. 
+ +### Constants + +Match Java SDK: + +```python +LEASE_EXTEND_RETRY_COUNT = 3 # retries per heartbeat attempt +LEASE_EXTEND_DURATION_FACTOR = 0.8 # heartbeat at 80% of responseTimeoutSeconds +``` + +## Edge Cases + +| Scenario | Behavior | +|----------|----------| +| `lease_extend_enabled = False` (default) | No tracking, no heartbeats — zero overhead for existing users | +| `responseTimeoutSeconds = 0` | No tracking (no timeout to extend) | +| `responseTimeoutSeconds` very small (< 1.25s) | `interval = x * 0.8` — if < 1s, skip (too small to heartbeat meaningfully) | +| Poll loop slower than heartbeat interval | Heartbeat fires on next `run_once()` — slightly late but still within `responseTimeoutSeconds` since we fire at 80% | +| Task completes between heartbeat checks | `_untrack_lease()` removes it; no stale heartbeat sent | +| Heartbeat fails 3 times | Log error, keep tracking — next `run_once()` will retry. Task may timeout server-side if network-partitioned (correct behavior) | +| Worker process crashes | Tracking dict dies with process — server times out task after `responseTimeoutSeconds`, re-queues | +| `_shutdown` set | `run_once()` loop stops → no more heartbeats; `_cleanup()` clears tracking | +| v2 endpoint returns next task | Old task untracked, new task tracked with fresh timestamp | +| Multiple tasks in-flight (thread_count > 1) | Each tracked independently in `_lease_info` dict; `_send_due_heartbeats()` iterates all | + +## Files to Change + +1. **`src/conductor/client/automator/task_runner.py`** + - Add `_LeaseInfo` dataclass and constants + - Add `_lease_info` dict to `__init__()` + - Add `_track_lease()`, `_untrack_lease()`, `_send_due_heartbeats()`, `_send_heartbeat()` + - Wire into `run_once()` and `__execute_and_update_task()` + - Add `_lease_info.clear()` to `_cleanup()` + +2. **`src/conductor/client/automator/async_task_runner.py`** + - Same additions with async variants + - `_send_due_heartbeats_async()` uses `await` for API calls + +3. 
**`WORKER_CONFIGURATION.md`** + - Remove "Not implemented — reserved for future use" warning + - Document heartbeat behavior: 80% interval, retry logic, when to enable + +4. **Tests** + - Verify heartbeat sent when `lease_extend_enabled=True` and `responseTimeoutSeconds > 0` + - Verify NO heartbeat when `lease_extend_enabled=False` + - Verify NO heartbeat when `responseTimeoutSeconds = 0` + - Verify `_untrack_lease()` on task completion prevents further heartbeats + - Verify retry logic on API failure + - Integration test: long-running task with short `responseTimeoutSeconds` completes without timeout + +## Non-Goals + +- **Extending `timeoutSeconds`** — the overall SLA is a hard ceiling, unaffected by lease extension +- **Server-side changes** — server already supports `extendLease` fully +- **Dedicated API endpoint** — uses existing `POST /tasks` with `extendLease` flag (matches Java SDK) +- **Configurable heartbeat factor** — hardcoded to 0.8 (matches Java SDK; can be made configurable later if needed) diff --git a/examples/lease_extension_example.py b/examples/lease_extension_example.py new file mode 100644 index 000000000..ab63e49b8 --- /dev/null +++ b/examples/lease_extension_example.py @@ -0,0 +1,163 @@ +""" +Lease Extension (Automatic Heartbeat) Example +============================================== + +Demonstrates how lease extension keeps a long-running task alive +even when its execution time exceeds responseTimeoutSeconds. + +How it works: +- The task has responseTimeoutSeconds=30 (server times it out after 30s of inactivity) +- The worker sleeps for 60s (well beyond the timeout) +- With lease_extend_enabled=True, the SDK automatically sends heartbeats at 80% of + responseTimeoutSeconds (every 24s), resetting the server's timeout timer +- The task completes successfully despite running 2x longer than the timeout + +Without lease extension, the server would mark the task as TIMED_OUT after 30s. 
+ +Run: + export CONDUCTOR_SERVER_URL="http://localhost:8080/api" + python examples/lease_extension_example.py +""" + +import logging +import time + +from conductor.client.automator.task_handler import TaskHandler +from conductor.client.configuration.configuration import Configuration +from conductor.client.http.models.task_def import TaskDef +from conductor.client.http.models.workflow_def import WorkflowDef +from conductor.client.http.models.workflow_task import WorkflowTask +from conductor.client.http.models.start_workflow_request import StartWorkflowRequest +from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient +from conductor.client.orkes.orkes_workflow_client import OrkesWorkflowClient +from conductor.client.worker.worker_task import worker_task + +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', +) +logger = logging.getLogger(__name__) + +# Task timeout configuration +RESPONSE_TIMEOUT_SECONDS = 30 # Server times out after 30s of inactivity +TASK_SLEEP_SECONDS = 60 # Worker sleeps 60s (2x the timeout) + +WORKFLOW_NAME = 'lease_extension_demo' +TASK_NAME = 'lease_heartbeat_demo_task' + + +# --------------------------------------------------------------------------- +# Worker with lease extension enabled +# --------------------------------------------------------------------------- + +@worker_task( + task_definition_name=TASK_NAME, + lease_extend_enabled=True, # Heartbeats keep the lease alive + register_task_def=True, + task_def=TaskDef( + name=TASK_NAME, + response_timeout_seconds=RESPONSE_TIMEOUT_SECONDS, + timeout_seconds=300, # Overall SLA: 5 minutes + retry_count=0, + ), + overwrite_task_def=True, +) +def lease_heartbeat_demo_task(job_id: str) -> dict: + """ + Long-running task that sleeps longer than responseTimeoutSeconds. + + Without lease extension, this would time out after 30s. 
+ With lease extension, the SDK sends heartbeats at 24s intervals (80% of 30s), + keeping the task alive until completion. + """ + logger.info( + "Starting job %s — sleeping %ds (responseTimeout=%ds, heartbeat every %ds)", + job_id, TASK_SLEEP_SECONDS, RESPONSE_TIMEOUT_SECONDS, + int(RESPONSE_TIMEOUT_SECONDS * 0.8), + ) + time.sleep(TASK_SLEEP_SECONDS) + logger.info("Completed job %s", job_id) + return { + 'job_id': job_id, + 'status': 'completed', + 'slept_seconds': TASK_SLEEP_SECONDS, + 'response_timeout_seconds': RESPONSE_TIMEOUT_SECONDS, + } + + +# --------------------------------------------------------------------------- +# Workflow setup and execution +# --------------------------------------------------------------------------- + +def register_workflow(metadata_client: OrkesMetadataClient): + """Register a single-task workflow for the demo.""" + workflow = WorkflowDef(name=WORKFLOW_NAME, version=1) + task = WorkflowTask( + name=TASK_NAME, + task_reference_name=f'{TASK_NAME}_ref', + input_parameters={'job_id': '${workflow.input.job_id}'}, + ) + workflow._tasks = [task] + try: + metadata_client.update_workflow_def(workflow, overwrite=True) + except Exception: + metadata_client.register_workflow_def(workflow, overwrite=True) + logger.info("Registered workflow: %s", WORKFLOW_NAME) + + +def wait_for_workflow(workflow_client: OrkesWorkflowClient, wf_id: str, timeout: int = 120): + """Poll until the workflow reaches a terminal state.""" + for _ in range(timeout): + wf = workflow_client.get_workflow(wf_id, include_tasks=True) + if wf.status in ('COMPLETED', 'FAILED', 'TIMED_OUT', 'TERMINATED'): + return wf + time.sleep(1) + return workflow_client.get_workflow(wf_id, include_tasks=True) + + +def main(): + config = Configuration() + metadata_client = OrkesMetadataClient(config) + workflow_client = OrkesWorkflowClient(config) + + # Register the workflow definition + register_workflow(metadata_client) + + # Start workers (auto-discovers @worker_task functions) + with 
TaskHandler(configuration=config, scan_for_annotated_workers=True) as handler: + handler.start_processes() + time.sleep(2) # Let workers initialize + + # Start the workflow + req = StartWorkflowRequest() + req.name = WORKFLOW_NAME + req.version = 1 + req.input = {'job_id': 'DEMO-001'} + wf_id = workflow_client.start_workflow(start_workflow_request=req) + + print() + print("=" * 70) + print(f" Workflow started: {wf_id}") + print(f" Task sleeps {TASK_SLEEP_SECONDS}s with responseTimeout={RESPONSE_TIMEOUT_SECONDS}s") + print(f" Heartbeat interval: {int(RESPONSE_TIMEOUT_SECONDS * 0.8)}s (80% of timeout)") + print(f" UI: {config.ui_host}/execution/{wf_id}") + print("=" * 70) + print() + + # Wait for completion + wf = wait_for_workflow(workflow_client, wf_id, timeout=TASK_SLEEP_SECONDS + 30) + + print(f" Final status: {wf.status}") + for task in (wf.tasks or []): + print(f" Task '{task.task_def_name}': {task.status}") + if task.output_data: + print(f" Output: {task.output_data}") + + if wf.status == 'COMPLETED': + print("\n SUCCESS: Task completed with lease extension keeping it alive!") + else: + print(f"\n UNEXPECTED: Workflow ended with status {wf.status}") + + +if __name__ == '__main__': + main() diff --git a/src/conductor/client/automator/async_task_runner.py b/src/conductor/client/automator/async_task_runner.py index c0848523d..da8b81f4f 100644 --- a/src/conductor/client/automator/async_task_runner.py +++ b/src/conductor/client/automator/async_task_runner.py @@ -5,6 +5,7 @@ import sys import time import traceback +from dataclasses import dataclass from conductor.client.configuration.configuration import Configuration from conductor.client.configuration.settings.metrics_settings import MetricsSettings @@ -39,6 +40,20 @@ ) ) +# Lease extension constants (matches Java SDK) +LEASE_EXTEND_RETRY_COUNT = 3 +LEASE_EXTEND_DURATION_FACTOR = 0.8 + + +@dataclass +class _LeaseInfo: + """Tracks when a heartbeat is next due for an in-flight task.""" + task_id: str + 
workflow_instance_id: str + response_timeout_seconds: float + last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) + interval_seconds: float # 80% of responseTimeoutSeconds + class AsyncTaskRunner: """ @@ -112,6 +127,7 @@ def __init__( self._semaphore = None self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True # Will be set to False if server doesn't support v2 endpoint + self._lease_info = {} # task_id -> _LeaseInfo for lease extension heartbeats async def run(self) -> None: """Main async loop - runs continuously in single event loop.""" @@ -166,6 +182,9 @@ async def _cleanup(self) -> None: """Clean up async resources.""" logger.debug("Cleaning up AsyncTaskRunner resources...") + # Stop all lease extension tracking + self._lease_info.clear() + # Cancel any running tasks (EAFP style) try: for task in list(self._running_tasks): @@ -423,6 +442,9 @@ async def __async_register_task_definition(self) -> None: async def run_once(self) -> None: """Execute one iteration of the polling loop (async version).""" try: + # Send lease extension heartbeats for any tasks that are due + await self._send_due_heartbeats() + # No need for manual cleanup - tasks remove themselves via add_done_callback # Just check capacity directly current_capacity = len(self._running_tasks) @@ -573,6 +595,7 @@ async def __async_execute_and_update_task(self, task: Task) -> None: # Acquire semaphore for entire task lifecycle (execution + update) # This ensures we never exceed thread_count tasks in any stage of processing async with self._semaphore: + self._track_lease(task) try: while task is not None and not self._shutdown: task_result = await self.__async_execute_task(task) @@ -582,6 +605,7 @@ async def __async_execute_and_update_task(self, task: Task) -> None: return if task_result is None: return + self._untrack_lease(task.task_id) # Update task and get next task from v2 response task = await self.__async_update_task(task_result) # v2 
returns the next task; if v1 was used (returns None), immediately @@ -589,12 +613,17 @@ async def __async_execute_and_update_task(self, task: Task) -> None: if task is None and not self._use_update_v2 and not self._shutdown: tasks = await self.__async_batch_poll(1) task = tasks[0] if tasks else None + if task is not None: + self._track_lease(task) except Exception as e: logger.error( "Error executing/updating task %s: %s", task.task_id if task else "unknown", traceback.format_exc() ) + finally: + if task is not None: + self._untrack_lease(task.task_id) async def __async_execute_task(self, task: Task) -> TaskResult: """Execute async worker function directly (no threads, no BackgroundEventLoop).""" @@ -908,6 +937,71 @@ async def __async_update_task(self, task_result: TaskResult): return None + # -- Lease extension (heartbeat) methods ---------------------------------- + + def _track_lease(self, task) -> None: + """Start tracking a task for lease extension heartbeat.""" + if not getattr(self.worker, 'lease_extend_enabled', False): + return + timeout = getattr(task, 'response_timeout_seconds', None) or 0 + if timeout <= 0: + return + interval = timeout * LEASE_EXTEND_DURATION_FACTOR + if interval < 1: + return + self._lease_info[task.task_id] = _LeaseInfo( + task_id=task.task_id, + workflow_instance_id=task.workflow_instance_id, + response_timeout_seconds=timeout, + last_heartbeat_time=time.monotonic(), + interval_seconds=interval, + ) + logger.debug( + "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", + task.task_id, timeout, interval, + ) + + def _untrack_lease(self, task_id: str) -> None: + """Stop tracking a task for lease extension.""" + removed = self._lease_info.pop(task_id, None) + if removed is not None: + logger.debug("Untracked lease for task %s", task_id) + + async def _send_due_heartbeats(self) -> None: + """Check all tracked tasks and send heartbeats for any that are due.""" + if not self._lease_info: + return + now = time.monotonic() + for 
info in list(self._lease_info.values()): + elapsed = now - info.last_heartbeat_time + if elapsed < info.interval_seconds: + continue + await self._send_heartbeat(info) + info.last_heartbeat_time = time.monotonic() + + async def _send_heartbeat(self, info: _LeaseInfo) -> None: + """Send a single lease extension heartbeat with retry (async).""" + result = TaskResult( + task_id=info.task_id, + workflow_instance_id=info.workflow_instance_id, + extend_lease=True, + ) + for attempt in range(LEASE_EXTEND_RETRY_COUNT): + try: + await self.async_task_client.update_task(body=result) + logger.debug("Extended lease for task %s", info.task_id) + return + except Exception: + if attempt < LEASE_EXTEND_RETRY_COUNT - 1: + await asyncio.sleep(0.5 * (attempt + 2)) + else: + logger.error( + "Failed to extend lease for task %s after %d attempts", + info.task_id, LEASE_EXTEND_RETRY_COUNT, + ) + + # -------------------------------------------------------------------------- + def __set_worker_properties(self) -> None: """ Resolve worker configuration using hierarchical override (same as TaskRunner). 
diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py index 1541976ad..2949d5afe 100644 --- a/src/conductor/client/automator/task_runner.py +++ b/src/conductor/client/automator/task_runner.py @@ -5,6 +5,7 @@ import time import traceback from concurrent.futures import ThreadPoolExecutor, as_completed +from dataclasses import dataclass, field from typing import List, Optional, Any from conductor.client.configuration.configuration import Configuration @@ -41,6 +42,20 @@ ) ) +# Lease extension constants (matches Java SDK) +LEASE_EXTEND_RETRY_COUNT = 3 +LEASE_EXTEND_DURATION_FACTOR = 0.8 + + +@dataclass +class _LeaseInfo: + """Tracks when a heartbeat is next due for an in-flight task.""" + task_id: str + workflow_instance_id: str + response_timeout_seconds: float + last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) + interval_seconds: float # 80% of responseTimeoutSeconds + class TaskRunner: def __init__( @@ -93,6 +108,7 @@ def __init__( self._consecutive_empty_polls = 0 # Track empty polls to implement backoff self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True # Will be set to False if server doesn't support v2 endpoint + self._lease_info = {} # task_id -> _LeaseInfo for lease extension heartbeats def run(self) -> None: if self.configuration is not None: @@ -132,6 +148,9 @@ def _cleanup(self) -> None: """Clean up resources - called on exit.""" logger.debug("Cleaning up TaskRunner resources...") + # Stop all lease extension tracking + self._lease_info.clear() + # Shutdown ThreadPoolExecutor (EAFP style - more Pythonic) try: self._executor.shutdown(wait=True, cancel_futures=True) @@ -391,6 +410,9 @@ def __register_task_definition(self) -> None: def run_once(self) -> None: try: + # Send lease extension heartbeats for any tasks that are due + self._send_due_heartbeats() + # Check completed async tasks first (non-blocking) 
self.__check_completed_async_tasks() @@ -511,6 +533,7 @@ def __execute_and_update_task(self, task: Task) -> None: The loop breaks when no next task is available, the task is async/in-progress, or shutdown is requested. """ + self._track_lease(task) try: while task is not None and not self._shutdown: task_result = self.__execute_task(task) @@ -522,6 +545,7 @@ def __execute_and_update_task(self, task: Task) -> None: if isinstance(task_result, TaskInProgress): logger.debug("Task %s is in progress, will update when complete", task.task_id) return + self._untrack_lease(task.task_id) # Update task and get next task from v2 response task = self.__update_task(task_result) # v2 returns the next task; if v1 was used (returns None), immediately @@ -529,12 +553,17 @@ def __execute_and_update_task(self, task: Task) -> None: if task is None and not self._use_update_v2 and not self._shutdown: tasks = self.__batch_poll_tasks(1) task = tasks[0] if tasks else None + if task is not None: + self._track_lease(task) except Exception as e: logger.error( "Error executing/updating task %s: %s", task.task_id if task else "unknown", traceback.format_exc() ) + finally: + if task is not None: + self._untrack_lease(task.task_id) def __batch_poll_tasks(self, count: int) -> list: """Poll for multiple tasks at once (more efficient than polling one at a time)""" @@ -938,6 +967,72 @@ def __update_task(self, task_result: TaskResult): return None + # -- Lease extension (heartbeat) methods ---------------------------------- + + def _track_lease(self, task: Task) -> None: + """Start tracking a task for lease extension heartbeat.""" + lease_enabled = getattr(self.worker, 'lease_extend_enabled', False) + if not lease_enabled: + return + timeout = getattr(task, 'response_timeout_seconds', None) or 0 + if timeout <= 0: + return + interval = timeout * LEASE_EXTEND_DURATION_FACTOR + if interval < 1: + return + self._lease_info[task.task_id] = _LeaseInfo( + task_id=task.task_id, + 
workflow_instance_id=task.workflow_instance_id, + response_timeout_seconds=timeout, + last_heartbeat_time=time.monotonic(), + interval_seconds=interval, + ) + logger.debug( + "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", + task.task_id, timeout, interval, + ) + + def _untrack_lease(self, task_id: str) -> None: + """Stop tracking a task for lease extension.""" + removed = self._lease_info.pop(task_id, None) + if removed is not None: + logger.debug("Untracked lease for task %s", task_id) + + def _send_due_heartbeats(self) -> None: + """Check all tracked tasks and send heartbeats for any that are due.""" + if not self._lease_info: + return + now = time.monotonic() + for info in list(self._lease_info.values()): + elapsed = now - info.last_heartbeat_time + if elapsed < info.interval_seconds: + continue + self._send_heartbeat(info) + info.last_heartbeat_time = time.monotonic() + + def _send_heartbeat(self, info: _LeaseInfo) -> None: + """Send a single lease extension heartbeat with retry.""" + result = TaskResult( + task_id=info.task_id, + workflow_instance_id=info.workflow_instance_id, + extend_lease=True, + ) + for attempt in range(LEASE_EXTEND_RETRY_COUNT): + try: + self.task_client.update_task(body=result) + logger.debug("Extended lease for task %s", info.task_id) + return + except Exception: + if attempt < LEASE_EXTEND_RETRY_COUNT - 1: + time.sleep(0.5 * (attempt + 2)) + else: + logger.error( + "Failed to extend lease for task %s after %d attempts", + info.task_id, LEASE_EXTEND_RETRY_COUNT, + ) + + # -------------------------------------------------------------------------- + def __wait_for_polling_interval(self) -> None: polling_interval = self.worker.get_polling_interval_in_seconds() time.sleep(polling_interval) diff --git a/tests/integration/test_lease_extension.py b/tests/integration/test_lease_extension.py new file mode 100644 index 000000000..0bb72a56c --- /dev/null +++ b/tests/integration/test_lease_extension.py @@ -0,0 +1,234 @@ +""" +E2E 
test for lease extension (heartbeat) feature. + +Proves that: +1. WITH lease extension enabled: a long-running task completes successfully + even when its execution time exceeds responseTimeoutSeconds, because + heartbeats keep the lease alive. + +2. WITHOUT lease extension: the same long-running task times out on the + server after responseTimeoutSeconds and is retried/failed. + +Run: + export CONDUCTOR_SERVER_URL="http://localhost:6767/api" + python3 -m pytest tests/integration/test_lease_extension.py -v -s + +Prerequisites: + - Conductor server running (default: http://localhost:6767/api) +""" + +import logging +import os +import sys +import time +import threading +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +from conductor.client.automator.task_handler import TaskHandler +from conductor.client.configuration.configuration import Configuration +from conductor.client.worker.worker_task import worker_task +from conductor.client.http.models.workflow_def import WorkflowDef +from conductor.client.http.models.task_def import TaskDef +from conductor.client.http.models.workflow_task import WorkflowTask +from conductor.client.http.models.start_workflow_request import StartWorkflowRequest +from conductor.client.orkes.orkes_workflow_client import OrkesWorkflowClient +from conductor.client.orkes.orkes_metadata_client import OrkesMetadataClient + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +# Short response timeout — task must heartbeat to stay alive +RESPONSE_TIMEOUT_SECONDS = 10 + +# Task sleeps longer than the response timeout to prove heartbeat works. +# Must be long enough that the server's workflow sweeper (which runs every +# ~30s) catches the expired task BEFORE the worker completes. 
+TASK_SLEEP_SECONDS = 50 + + +# -- Workers ----------------------------------------------------------------- + +# Worker WITH lease extension enabled — heartbeats keep it alive +@worker_task( + task_definition_name='lease_heartbeat_task', + lease_extend_enabled=True, + register_task_def=True, + task_def=TaskDef( + name='lease_heartbeat_task', + response_timeout_seconds=RESPONSE_TIMEOUT_SECONDS, + timeout_seconds=180, + retry_count=0, + ), + overwrite_task_def=True, +) +def lease_heartbeat_task(job_id: str) -> dict: + """Long-running task with heartbeat — should complete.""" + logger.info("[heartbeat_task] Starting job %s, sleeping %ss (timeout=%ss)", + job_id, TASK_SLEEP_SECONDS, RESPONSE_TIMEOUT_SECONDS) + time.sleep(TASK_SLEEP_SECONDS) + logger.info("[heartbeat_task] Completed job %s", job_id) + return {'job_id': job_id, 'status': 'completed', 'slept': TASK_SLEEP_SECONDS} + + +# Worker WITHOUT lease extension — will time out +@worker_task( + task_definition_name='lease_no_heartbeat_task', + lease_extend_enabled=False, + register_task_def=True, + task_def=TaskDef( + name='lease_no_heartbeat_task', + response_timeout_seconds=RESPONSE_TIMEOUT_SECONDS, + timeout_seconds=120, + retry_count=0, + ), + overwrite_task_def=True, +) +def lease_no_heartbeat_task(job_id: str) -> dict: + """Long-running task without heartbeat — should time out.""" + logger.info("[no_heartbeat_task] Starting job %s, sleeping %ss (timeout=%ss)", + job_id, TASK_SLEEP_SECONDS, RESPONSE_TIMEOUT_SECONDS) + time.sleep(TASK_SLEEP_SECONDS) + logger.info("[no_heartbeat_task] Completed job %s", job_id) + return {'job_id': job_id, 'status': 'completed', 'slept': TASK_SLEEP_SECONDS} + + +# -- Test class -------------------------------------------------------------- + +class TestLeaseExtension(unittest.TestCase): + + @classmethod + def setUpClass(cls): + from tests.integration.conftest import skip_if_server_unavailable + skip_if_server_unavailable() + + cls.config = Configuration() + cls.metadata_client = 
OrkesMetadataClient(cls.config) + cls.workflow_client = OrkesWorkflowClient(cls.config) + + def _register_workflow(self, wf_name, task_name): + """Register a single-task workflow.""" + workflow = WorkflowDef(name=wf_name, version=1) + task = WorkflowTask( + name=task_name, + task_reference_name=f'{task_name}_ref', + input_parameters={'job_id': '${workflow.input.job_id}'}, + ) + workflow._tasks = [task] + try: + self.metadata_client.update_workflow_def(workflow, overwrite=True) + except Exception: + self.metadata_client.register_workflow_def(workflow, overwrite=True) + logger.info("Registered workflow: %s", wf_name) + + def _start_workflow(self, wf_name, job_id): + """Start a workflow and return the execution ID.""" + req = StartWorkflowRequest() + req.name = wf_name + req.version = 1 + req.input = {'job_id': job_id} + wf_id = self.workflow_client.start_workflow(start_workflow_request=req) + logger.info("Started workflow %s: %s", wf_name, wf_id) + return wf_id + + def _wait_for_workflow(self, wf_id, timeout_seconds=60): + """Poll until workflow reaches a terminal state.""" + for i in range(timeout_seconds): + wf = self.workflow_client.get_workflow(wf_id, include_tasks=True) + if wf.status in ('COMPLETED', 'FAILED', 'TIMED_OUT', 'TERMINATED'): + return wf + time.sleep(1) + # Return whatever state it's in after timeout + return self.workflow_client.get_workflow(wf_id, include_tasks=True) + + def _run_workers_in_background(self, duration_seconds=60): + """Start workers in a background thread, return stop function.""" + handler = TaskHandler( + configuration=self.config, + scan_for_annotated_workers=True, + ) + handler.start_processes() + + def stop(): + handler.stop_processes() + + # Auto-stop after duration + timer = threading.Timer(duration_seconds, stop) + timer.daemon = True + timer.start() + + return stop + + def test_01_with_heartbeat_completes(self): + """Task WITH lease_extend_enabled=True completes even when sleep > responseTimeout.""" + print("\n" + "=" * 80) 
+ print("TEST: With heartbeat — task should COMPLETE") + print(f" responseTimeoutSeconds={RESPONSE_TIMEOUT_SECONDS}s, task sleeps {TASK_SLEEP_SECONDS}s") + print("=" * 80) + + wf_name = 'test_lease_heartbeat' + self._register_workflow(wf_name, 'lease_heartbeat_task') + + stop_workers = self._run_workers_in_background(duration_seconds=90) + time.sleep(3) # let workers start + + try: + wf_id = self._start_workflow(wf_name, 'HEARTBEAT-001') + wf = self._wait_for_workflow(wf_id, timeout_seconds=80) + + print(f"\n Final status: {wf.status}") + for task in (wf.tasks or []): + print(f" Task {task.task_def_name}: {task.status}") + + self.assertEqual(wf.status, 'COMPLETED', + f"Workflow should COMPLETE with heartbeat, got {wf.status}") + + # Verify task output + tasks_by_ref = {t.reference_task_name: t for t in wf.tasks} + task = tasks_by_ref.get('lease_heartbeat_task_ref') + self.assertIsNotNone(task) + self.assertEqual(task.status, 'COMPLETED') + self.assertEqual(task.output_data.get('job_id'), 'HEARTBEAT-001') + self.assertEqual(task.output_data.get('slept'), TASK_SLEEP_SECONDS) + print("\n PASS: Task completed with heartbeat keeping lease alive") + finally: + stop_workers() + + def test_02_without_heartbeat_times_out(self): + """Task WITHOUT lease_extend_enabled times out when sleep > responseTimeout.""" + print("\n" + "=" * 80) + print("TEST: Without heartbeat — task should TIME OUT") + print(f" responseTimeoutSeconds={RESPONSE_TIMEOUT_SECONDS}s, task sleeps {TASK_SLEEP_SECONDS}s") + print("=" * 80) + + wf_name = 'test_lease_no_heartbeat' + self._register_workflow(wf_name, 'lease_no_heartbeat_task') + + stop_workers = self._run_workers_in_background(duration_seconds=90) + time.sleep(3) # let workers start + + try: + wf_id = self._start_workflow(wf_name, 'NO-HEARTBEAT-001') + wf = self._wait_for_workflow(wf_id, timeout_seconds=80) + + print(f"\n Final status: {wf.status}") + for task in (wf.tasks or []): + print(f" Task {task.task_def_name}: {task.status}") + + # 
Without heartbeat, the task should timeout or fail + self.assertIn(wf.status, ('FAILED', 'TIMED_OUT'), + f"Workflow should FAIL/TIMEOUT without heartbeat, got {wf.status}") + + tasks_by_ref = {t.reference_task_name: t for t in wf.tasks} + task = tasks_by_ref.get('lease_no_heartbeat_task_ref') + self.assertIsNotNone(task) + self.assertIn(task.status, ('TIMED_OUT', 'FAILED', 'CANCELED'), + f"Task should be TIMED_OUT/FAILED, got {task.status}") + print("\n PASS: Task timed out as expected without heartbeat") + finally: + stop_workers() + + +if __name__ == '__main__': + unittest.main() From cf1d7f0255e27f400096d5d542b87c89d854dd35 Mon Sep 17 00:00:00 2001 From: Viren Baraiya Date: Sat, 18 Apr 2026 14:56:37 -0700 Subject: [PATCH 2/4] fix(lease): async task lease leak, thread safety, and code review fixes MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix lease leak: when __execute_task returns None (ASYNC_TASK_RUNNING), keep the lease tracked instead of untracking in the finally block. Untrack in __check_completed_async_tasks when the async task finishes. - Remove dead isinstance(task_result, TaskInProgress) checks — __execute_task always wraps TaskInProgress into TaskResult before returning. - Add threading.Lock around _lease_info dict for PEP 703 free-threaded Python compatibility (safe beyond CPython GIL). - Log exception cause in _send_heartbeat retry failures (was swallowed). - Extract LeaseInfo dataclass and constants into shared lease_tracker.py module to eliminate duplication between task_runner and async_task_runner. 
Co-Authored-By: Claude Opus 4.6 --- .../client/automator/async_task_runner.py | 32 +++------- .../client/automator/lease_tracker.py | 17 ++++++ src/conductor/client/automator/task_runner.py | 60 +++++++++---------- 3 files changed, 54 insertions(+), 55 deletions(-) create mode 100644 src/conductor/client/automator/lease_tracker.py diff --git a/src/conductor/client/automator/async_task_runner.py b/src/conductor/client/automator/async_task_runner.py index da8b81f4f..e07596f35 100644 --- a/src/conductor/client/automator/async_task_runner.py +++ b/src/conductor/client/automator/async_task_runner.py @@ -5,7 +5,6 @@ import sys import time import traceback -from dataclasses import dataclass from conductor.client.configuration.configuration import Configuration from conductor.client.configuration.settings.metrics_settings import MetricsSettings @@ -33,6 +32,7 @@ from conductor.client.worker.worker_config import resolve_worker_config, get_worker_config_oneline from conductor.client.worker.exception import NonRetryableException from conductor.client.automator.json_schema_generator import generate_json_schema_from_function +from conductor.client.automator.lease_tracker import LeaseInfo, LEASE_EXTEND_RETRY_COUNT, LEASE_EXTEND_DURATION_FACTOR logger = logging.getLogger( Configuration.get_logging_formatted_name( @@ -40,20 +40,6 @@ ) ) -# Lease extension constants (matches Java SDK) -LEASE_EXTEND_RETRY_COUNT = 3 -LEASE_EXTEND_DURATION_FACTOR = 0.8 - - -@dataclass -class _LeaseInfo: - """Tracks when a heartbeat is next due for an in-flight task.""" - task_id: str - workflow_instance_id: str - response_timeout_seconds: float - last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) - interval_seconds: float # 80% of responseTimeoutSeconds - class AsyncTaskRunner: """ @@ -127,7 +113,7 @@ def __init__( self._semaphore = None self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True # Will be set to False if server doesn't support 
v2 endpoint - self._lease_info = {} # task_id -> _LeaseInfo for lease extension heartbeats + self._lease_info = {} # task_id -> LeaseInfo for lease extension heartbeats async def run(self) -> None: """Main async loop - runs continuously in single event loop.""" @@ -599,10 +585,6 @@ async def __async_execute_and_update_task(self, task: Task) -> None: try: while task is not None and not self._shutdown: task_result = await self.__async_execute_task(task) - # If task returned TaskInProgress, don't update yet - if isinstance(task_result, TaskInProgress): - logger.debug("Task %s is in progress, will update when complete", task.task_id) - return if task_result is None: return self._untrack_lease(task.task_id) @@ -949,7 +931,7 @@ def _track_lease(self, task) -> None: interval = timeout * LEASE_EXTEND_DURATION_FACTOR if interval < 1: return - self._lease_info[task.task_id] = _LeaseInfo( + self._lease_info[task.task_id] = LeaseInfo( task_id=task.task_id, workflow_instance_id=task.workflow_instance_id, response_timeout_seconds=timeout, @@ -979,7 +961,7 @@ async def _send_due_heartbeats(self) -> None: await self._send_heartbeat(info) info.last_heartbeat_time = time.monotonic() - async def _send_heartbeat(self, info: _LeaseInfo) -> None: + async def _send_heartbeat(self, info: LeaseInfo) -> None: """Send a single lease extension heartbeat with retry (async).""" result = TaskResult( task_id=info.task_id, @@ -991,13 +973,13 @@ async def _send_heartbeat(self, info: _LeaseInfo) -> None: await self.async_task_client.update_task(body=result) logger.debug("Extended lease for task %s", info.task_id) return - except Exception: + except Exception as e: if attempt < LEASE_EXTEND_RETRY_COUNT - 1: await asyncio.sleep(0.5 * (attempt + 2)) else: logger.error( - "Failed to extend lease for task %s after %d attempts", - info.task_id, LEASE_EXTEND_RETRY_COUNT, + "Failed to extend lease for task %s after %d attempts: %s", + info.task_id, LEASE_EXTEND_RETRY_COUNT, e, ) # 
-------------------------------------------------------------------------- diff --git a/src/conductor/client/automator/lease_tracker.py b/src/conductor/client/automator/lease_tracker.py new file mode 100644 index 000000000..794e54e2c --- /dev/null +++ b/src/conductor/client/automator/lease_tracker.py @@ -0,0 +1,17 @@ +"""Shared lease extension (heartbeat) tracking for TaskRunner and AsyncTaskRunner.""" + +from dataclasses import dataclass + +# Lease extension constants (matches Java SDK) +LEASE_EXTEND_RETRY_COUNT = 3 +LEASE_EXTEND_DURATION_FACTOR = 0.8 + + +@dataclass +class LeaseInfo: + """Tracks when a heartbeat is next due for an in-flight task.""" + task_id: str + workflow_instance_id: str + response_timeout_seconds: float + last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) + interval_seconds: float # 80% of responseTimeoutSeconds diff --git a/src/conductor/client/automator/task_runner.py b/src/conductor/client/automator/task_runner.py index 2949d5afe..f14f5994c 100644 --- a/src/conductor/client/automator/task_runner.py +++ b/src/conductor/client/automator/task_runner.py @@ -3,9 +3,10 @@ import os import sys import time +import threading import traceback from concurrent.futures import ThreadPoolExecutor, as_completed -from dataclasses import dataclass, field + from typing import List, Optional, Any from conductor.client.configuration.configuration import Configuration @@ -35,6 +36,7 @@ from conductor.client.worker.worker_config import resolve_worker_config, get_worker_config_oneline from conductor.client.worker.exception import NonRetryableException from conductor.client.automator.json_schema_generator import generate_json_schema_from_function +from conductor.client.automator.lease_tracker import LeaseInfo, LEASE_EXTEND_RETRY_COUNT, LEASE_EXTEND_DURATION_FACTOR logger = logging.getLogger( Configuration.get_logging_formatted_name( @@ -42,20 +44,6 @@ ) ) -# Lease extension constants (matches Java SDK) -LEASE_EXTEND_RETRY_COUNT = 3 
-LEASE_EXTEND_DURATION_FACTOR = 0.8 - - -@dataclass -class _LeaseInfo: - """Tracks when a heartbeat is next due for an in-flight task.""" - task_id: str - workflow_instance_id: str - response_timeout_seconds: float - last_heartbeat_time: float # time.monotonic() of last heartbeat (or task start) - interval_seconds: float # 80% of responseTimeoutSeconds - class TaskRunner: def __init__( @@ -108,7 +96,8 @@ def __init__( self._consecutive_empty_polls = 0 # Track empty polls to implement backoff self._shutdown = False # Flag to indicate graceful shutdown self._use_update_v2 = True # Will be set to False if server doesn't support v2 endpoint - self._lease_info = {} # task_id -> _LeaseInfo for lease extension heartbeats + self._lease_info = {} # task_id -> LeaseInfo for lease extension heartbeats + self._lease_lock = threading.Lock() # Protects _lease_info for free-threaded Python def run(self) -> None: if self.configuration is not None: @@ -149,7 +138,8 @@ def _cleanup(self) -> None: logger.debug("Cleaning up TaskRunner resources...") # Stop all lease extension tracking - self._lease_info.clear() + with self._lease_lock: + self._lease_info.clear() # Shutdown ThreadPoolExecutor (EAFP style - more Pythonic) try: @@ -487,6 +477,9 @@ def __check_completed_async_tasks(self) -> None: for task_id, task_result, submit_time, task in completed: try: + # Async task finished — stop heartbeating for it + self._untrack_lease(task_id) + # Calculate actual execution time (from submission to completion) finish_time = time.time() time_spent = finish_time - submit_time @@ -534,16 +527,16 @@ def __execute_and_update_task(self, task: Task) -> None: or shutdown is requested. 
""" self._track_lease(task) + async_running = False # True when task is running async in background try: while task is not None and not self._shutdown: task_result = self.__execute_task(task) - # If task returned None, it's an async task running in background - don't update yet + # If task returned None, it's an async task running in background. + # Keep the lease tracked — __check_completed_async_tasks will untrack + # when the async task finishes. if task_result is None: logger.debug("Task %s is running async, will update when complete", task.task_id) - return - # If task returned TaskInProgress, it's running async - don't update yet - if isinstance(task_result, TaskInProgress): - logger.debug("Task %s is in progress, will update when complete", task.task_id) + async_running = True return self._untrack_lease(task.task_id) # Update task and get next task from v2 response @@ -562,7 +555,9 @@ def __execute_and_update_task(self, task: Task) -> None: traceback.format_exc() ) finally: - if task is not None: + # Don't untrack if the task is still running async in the background — + # the lease must stay active until __check_completed_async_tasks handles it. 
+ if task is not None and not async_running: self._untrack_lease(task.task_id) def __batch_poll_tasks(self, count: int) -> list: @@ -980,13 +975,15 @@ def _track_lease(self, task: Task) -> None: interval = timeout * LEASE_EXTEND_DURATION_FACTOR if interval < 1: return - self._lease_info[task.task_id] = _LeaseInfo( + info = LeaseInfo( task_id=task.task_id, workflow_instance_id=task.workflow_instance_id, response_timeout_seconds=timeout, last_heartbeat_time=time.monotonic(), interval_seconds=interval, ) + with self._lease_lock: + self._lease_info[task.task_id] = info logger.debug( "Tracking lease for task %s (timeout=%ss, heartbeat every %ss)", task.task_id, timeout, interval, @@ -994,7 +991,8 @@ def _track_lease(self, task: Task) -> None: def _untrack_lease(self, task_id: str) -> None: """Stop tracking a task for lease extension.""" - removed = self._lease_info.pop(task_id, None) + with self._lease_lock: + removed = self._lease_info.pop(task_id, None) if removed is not None: logger.debug("Untracked lease for task %s", task_id) @@ -1003,14 +1001,16 @@ def _send_due_heartbeats(self) -> None: if not self._lease_info: return now = time.monotonic() - for info in list(self._lease_info.values()): + with self._lease_lock: + infos = list(self._lease_info.values()) + for info in infos: elapsed = now - info.last_heartbeat_time if elapsed < info.interval_seconds: continue self._send_heartbeat(info) info.last_heartbeat_time = time.monotonic() - def _send_heartbeat(self, info: _LeaseInfo) -> None: + def _send_heartbeat(self, info: LeaseInfo) -> None: """Send a single lease extension heartbeat with retry.""" result = TaskResult( task_id=info.task_id, @@ -1022,13 +1022,13 @@ def _send_heartbeat(self, info: _LeaseInfo) -> None: self.task_client.update_task(body=result) logger.debug("Extended lease for task %s", info.task_id) return - except Exception: + except Exception as e: if attempt < LEASE_EXTEND_RETRY_COUNT - 1: time.sleep(0.5 * (attempt + 2)) else: logger.error( - "Failed to 
extend lease for task %s after %d attempts", - info.task_id, LEASE_EXTEND_RETRY_COUNT, + "Failed to extend lease for task %s after %d attempts: %s", + info.task_id, LEASE_EXTEND_RETRY_COUNT, e, ) # -------------------------------------------------------------------------- From c33f9b8284795e432a51ce8e3504fdff29b02cce Mon Sep 17 00:00:00 2001 From: Viren Baraiya Date: Sat, 18 Apr 2026 15:18:40 -0700 Subject: [PATCH 3/4] fix(http): break httpx Response reference cycle causing memory leak (#395) RESTResponse now eagerly reads resp.text into self.data and breaks the httpx Response <-> BoundSyncStream cycle by nulling resp.stream and resp._request. Drops io.IOBase inheritance (removes __del__ finalizer overhead). Removes write-only self.last_response retention. Adds json(), getheader() convenience methods to RESTResponse. Changes applied to both sync (rest.py, api_client.py) and async (async_rest.py, async_api_client.py) codepaths. Fixes #395 Co-Authored-By: Claude Opus 4.6 --- src/conductor/client/http/api_client.py | 6 +- src/conductor/client/http/async_api_client.py | 6 +- src/conductor/client/http/async_rest.py | 53 ++++++------- src/conductor/client/http/rest.py | 53 ++++++------- tests/unit/api_client/repro_memory_leak.py | 69 ++++++++++++++++ .../api_client/test_api_client_coverage.py | 8 +- tests/unit/api_client/test_memory_leak.py | 78 +++++++++++++++++++ 7 files changed, 203 insertions(+), 70 deletions(-) create mode 100644 tests/unit/api_client/repro_memory_leak.py create mode 100644 tests/unit/api_client/test_memory_leak.py diff --git a/src/conductor/client/http/api_client.py b/src/conductor/client/http/api_client.py index 4cc321922..761f78e28 100644 --- a/src/conductor/client/http/api_client.py +++ b/src/conductor/client/http/api_client.py @@ -173,8 +173,6 @@ def __call_api_no_retry( _preload_content=_preload_content, _request_timeout=_request_timeout) - self.last_response = response_data - return_data = response_data if _preload_content: # deserialize 
response data @@ -266,9 +264,9 @@ def deserialize(self, response, response_type): # fetch data from response object try: - data = response.resp.json() + data = response.json() except Exception: - data = response.resp.text + data = response.data try: return self.__deserialize(data, response_type) diff --git a/src/conductor/client/http/async_api_client.py b/src/conductor/client/http/async_api_client.py index 90bdf2674..3606573a2 100644 --- a/src/conductor/client/http/async_api_client.py +++ b/src/conductor/client/http/async_api_client.py @@ -184,8 +184,6 @@ async def __call_api_no_retry( _preload_content=_preload_content, _request_timeout=_request_timeout) - self.last_response = response_data - return_data = response_data if _preload_content: # deserialize response data @@ -277,9 +275,9 @@ def deserialize(self, response, response_type): # fetch data from response object try: - data = response.resp.json() + data = response.json() except Exception: - data = response.resp.text + data = response.data try: return self.__deserialize(data, response_type) diff --git a/src/conductor/client/http/async_rest.py b/src/conductor/client/http/async_rest.py index 9fb948eff..d6393c52f 100644 --- a/src/conductor/client/http/async_rest.py +++ b/src/conductor/client/http/async_rest.py @@ -1,4 +1,3 @@ -import io import json import os import re @@ -7,39 +6,35 @@ from six.moves.urllib.parse import urlencode -class RESTResponse(io.IOBase): +class RESTResponse: def __init__(self, resp): self.status = resp.status_code - # httpx.Response doesn't have reason attribute, derive it from status_code - self.reason = resp.reason_phrase if hasattr(resp, 'reason_phrase') else self._get_reason_phrase(resp.status_code) - self.resp = resp + self.reason = getattr(resp, 'reason_phrase', '') or self._get_reason_phrase(resp.status_code) + self.data = resp.text # eagerly read body self.headers = resp.headers + # Break httpx Response <-> BoundSyncStream reference cycle (issue #395) + resp.stream = None + 
resp._request = None def _get_reason_phrase(self, status_code): """Get HTTP reason phrase from status code.""" phrases = { - 200: 'OK', - 201: 'Created', - 202: 'Accepted', - 204: 'No Content', - 301: 'Moved Permanently', - 302: 'Found', - 304: 'Not Modified', - 400: 'Bad Request', - 401: 'Unauthorized', - 403: 'Forbidden', - 404: 'Not Found', - 405: 'Method Not Allowed', - 409: 'Conflict', - 429: 'Too Many Requests', - 500: 'Internal Server Error', - 502: 'Bad Gateway', - 503: 'Service Unavailable', - 504: 'Gateway Timeout', + 200: 'OK', 201: 'Created', 202: 'Accepted', 204: 'No Content', + 301: 'Moved Permanently', 302: 'Found', 304: 'Not Modified', + 400: 'Bad Request', 401: 'Unauthorized', 403: 'Forbidden', + 404: 'Not Found', 405: 'Method Not Allowed', 409: 'Conflict', + 429: 'Too Many Requests', 500: 'Internal Server Error', + 502: 'Bad Gateway', 503: 'Service Unavailable', 504: 'Gateway Timeout', } return phrases.get(status_code, 'Unknown') + def json(self): + return json.loads(self.data) + + def getheader(self, name, default=None): + return self.headers.get(name, default) + def getheaders(self): return self.headers @@ -283,15 +278,15 @@ def __init__(self, status=None, reason=None, http_resp=None, body=None): self.status = http_resp.status self.code = http_resp.status self.reason = http_resp.reason - self.body = http_resp.resp.text + self.body = http_resp.data try: - if http_resp.resp.text: - error = json.loads(http_resp.resp.text) + if http_resp.data: + error = json.loads(http_resp.data) self.message = error['message'] else: - self.message = http_resp.resp.text + self.message = http_resp.data except Exception as e: - self.message = http_resp.resp.text + self.message = http_resp.data self.headers = http_resp.getheaders() else: self.status = status @@ -324,7 +319,7 @@ def is_not_found(self) -> bool: class AuthorizationException(ApiException): def __init__(self, status=None, reason=None, http_resp=None, body=None): try: - data = json.loads(http_resp.resp.text) 
+ data = json.loads(http_resp.data) if 'error' in data: self._error_code = data['error'] else: diff --git a/src/conductor/client/http/rest.py b/src/conductor/client/http/rest.py index aedcbc952..8fbce2d1b 100644 --- a/src/conductor/client/http/rest.py +++ b/src/conductor/client/http/rest.py @@ -1,4 +1,3 @@ -import io import json import os import re @@ -7,39 +6,35 @@ from six.moves.urllib.parse import urlencode -class RESTResponse(io.IOBase): +class RESTResponse: def __init__(self, resp): self.status = resp.status_code - # httpx.Response doesn't have reason attribute, derive it from status_code - self.reason = resp.reason_phrase if hasattr(resp, 'reason_phrase') else self._get_reason_phrase(resp.status_code) - self.resp = resp + self.reason = getattr(resp, 'reason_phrase', '') or self._get_reason_phrase(resp.status_code) + self.data = resp.text # eagerly read body self.headers = resp.headers + # Break httpx Response <-> BoundSyncStream reference cycle (issue #395) + resp.stream = None + resp._request = None def _get_reason_phrase(self, status_code): """Get HTTP reason phrase from status code.""" phrases = { - 200: 'OK', - 201: 'Created', - 202: 'Accepted', - 204: 'No Content', - 301: 'Moved Permanently', - 302: 'Found', - 304: 'Not Modified', - 400: 'Bad Request', - 401: 'Unauthorized', - 403: 'Forbidden', - 404: 'Not Found', - 405: 'Method Not Allowed', - 409: 'Conflict', - 429: 'Too Many Requests', - 500: 'Internal Server Error', - 502: 'Bad Gateway', - 503: 'Service Unavailable', - 504: 'Gateway Timeout', + 200: 'OK', 201: 'Created', 202: 'Accepted', 204: 'No Content', + 301: 'Moved Permanently', 302: 'Found', 304: 'Not Modified', + 400: 'Bad Request', 401: 'Unauthorized', 403: 'Forbidden', + 404: 'Not Found', 405: 'Method Not Allowed', 409: 'Conflict', + 429: 'Too Many Requests', 500: 'Internal Server Error', + 502: 'Bad Gateway', 503: 'Service Unavailable', 504: 'Gateway Timeout', } return phrases.get(status_code, 'Unknown') + def json(self): + return 
json.loads(self.data) + + def getheader(self, name, default=None): + return self.headers.get(name, default) + def getheaders(self): return self.headers @@ -290,15 +285,15 @@ def __init__(self, status=None, reason=None, http_resp=None, body=None): self.status = http_resp.status self.code = http_resp.status self.reason = http_resp.reason - self.body = http_resp.resp.text + self.body = http_resp.data try: - if http_resp.resp.text: - error = json.loads(http_resp.resp.text) + if http_resp.data: + error = json.loads(http_resp.data) self.message = error['message'] else: - self.message = http_resp.resp.text + self.message = http_resp.data except Exception as e: - self.message = http_resp.resp.text + self.message = http_resp.data self.headers = http_resp.getheaders() else: self.status = status @@ -332,7 +327,7 @@ def is_not_found(self) -> bool: class AuthorizationException(ApiException): def __init__(self, status=None, reason=None, http_resp=None, body=None): try: - data = json.loads(http_resp.resp.text) + data = json.loads(http_resp.data) if 'error' in data: self._error_code = data['error'] else: diff --git a/tests/unit/api_client/repro_memory_leak.py b/tests/unit/api_client/repro_memory_leak.py new file mode 100644 index 000000000..4fc757f7f --- /dev/null +++ b/tests/unit/api_client/repro_memory_leak.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +Standalone reproduction script for issue #395. +Run before and after the fix to see the difference. 
+ +Usage: + python tests/unit/api_client/repro_memory_leak.py + +Before fix: memory grows ~400+ KB over 2000 requests +After fix: memory stays flat (< 50 KB growth) +""" +import gc +import tracemalloc + +import httpx + +from conductor.client.http.rest import RESTResponse + + +class _EchoTransport(httpx.BaseTransport): + def handle_request(self, request: httpx.Request) -> httpx.Response: + return httpx.Response(200, content=b'{"status":"ok"}') + + +def main(): + tracemalloc.start() + client = httpx.Client(transport=_EchoTransport()) + + # Warm up + for _ in range(100): + r = client.get("http://test/poll") + resp = RESTResponse(r) + _ = resp.data + del r, resp + gc.collect() + + snapshot_before = tracemalloc.take_snapshot() + + # Simulate 2000 poll cycles (~ 30 min of a real worker at 1 req/s) + for i in range(2000): + r = client.get("http://test/poll") + resp = RESTResponse(r) + _ = resp.data + del r, resp + + gc.collect() + snapshot_after = tracemalloc.take_snapshot() + + client.close() + + stats = snapshot_after.compare_to(snapshot_before, 'lineno') + + total_growth = sum(s.size_diff for s in stats if s.size_diff > 0) + print(f"\nTotal memory growth after 2000 requests: {total_growth / 1024:.1f} KB") + + if total_growth > 50 * 1024: + print("LEAK DETECTED - growth exceeds 50 KB threshold") + print("\nTop allocations:") + for s in stats[:10]: + if s.size_diff > 0: + print(f" {s}") + return 1 + else: + print("OK - no significant leak detected") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/tests/unit/api_client/test_api_client_coverage.py b/tests/unit/api_client/test_api_client_coverage.py index 1ec78978c..7ce0ffdca 100644 --- a/tests/unit/api_client/test_api_client_coverage.py +++ b/tests/unit/api_client/test_api_client_coverage.py @@ -182,7 +182,7 @@ def test_deserialize_with_json_response(self): # Mock response with JSON response = Mock() - response.resp.json.return_value = {'key': 'value'} + response.json.return_value = 
{'key': 'value'} result = client.deserialize(response, 'dict(str, str)') self.assertEqual(result, {'key': 'value'}) @@ -194,8 +194,8 @@ def test_deserialize_with_text_response(self): # Mock response that fails JSON parsing response = Mock() - response.resp.json.side_effect = Exception("Not JSON") - response.resp.text = "plain text" + response.json.side_effect = Exception("Not JSON") + response.data = "plain text" with patch.object(client, '_ApiClient__deserialize', return_value="deserialized") as mock_deserialize: result = client.deserialize(response, 'str') @@ -207,7 +207,7 @@ def test_deserialize_with_value_error(self): client = ApiClient(configuration=self.config) response = Mock() - response.resp.json.return_value = {'key': 'value'} + response.json.return_value = {'key': 'value'} with patch.object(client, '_ApiClient__deserialize', side_effect=ValueError("Invalid")): result = client.deserialize(response, 'SomeClass') diff --git a/tests/unit/api_client/test_memory_leak.py b/tests/unit/api_client/test_memory_leak.py new file mode 100644 index 000000000..1d82f2c47 --- /dev/null +++ b/tests/unit/api_client/test_memory_leak.py @@ -0,0 +1,78 @@ +""" +Reproduction and regression test for GitHub issue #395: +httpx.Response objects leak due to reference cycle in BoundSyncStream. + +The test creates real httpx responses (via httpx.Client against a local +transport), wraps them in RESTResponse, and verifies the httpx.Response +is eligible for garbage collection after the RESTResponse is consumed. 
+""" +import gc +import io +import weakref +import unittest + +import httpx + +from conductor.client.http.rest import RESTResponse + + +class _EchoTransport(httpx.BaseTransport): + """Returns a small JSON body for every request - no network needed.""" + + def handle_request(self, request: httpx.Request) -> httpx.Response: + return httpx.Response( + status_code=200, + headers={"content-type": "application/json"}, + content=b'{"ok": true}', + ) + + +class TestHttpxResponseMemoryLeak(unittest.TestCase): + """Regression: RESTResponse must not prevent httpx.Response GC.""" + + def test_httpx_response_does_not_leak(self): + """After wrapping in RESTResponse the raw httpx.Response must be GC-able.""" + client = httpx.Client(transport=_EchoTransport()) + + refs = [] + for _ in range(50): + raw = client.get("http://test/ping") + refs.append(weakref.ref(raw)) + rest_resp = RESTResponse(raw) + # Simulate what api_client does: read body then discard + _ = rest_resp.data + del raw, rest_resp + + # Force full collection (including cyclic GC) + gc.collect() + + alive = sum(1 for r in refs if r() is not None) + # Before the fix, all 50 would be alive. + # After the fix, none (or very few due to GC timing) should remain. 
+ self.assertLessEqual(alive, 2, f"{alive}/50 httpx.Response objects still alive - leak not fixed") + + client.close() + + def test_rest_response_attributes(self): + """RESTResponse exposes .data, .json(), .getheader(), .getheaders().""" + client = httpx.Client(transport=_EchoTransport()) + raw = client.get("http://test/ping") + resp = RESTResponse(raw) + + self.assertEqual(resp.status, 200) + self.assertEqual(resp.data, '{"ok": true}') + self.assertEqual(resp.json(), {"ok": True}) + self.assertEqual(resp.getheader("content-type"), "application/json") + self.assertIsNotNone(resp.getheaders()) + # After construction, the raw response should not be retained + self.assertFalse(hasattr(resp, 'resp')) + + client.close() + + def test_no_io_base_inheritance(self): + """RESTResponse must not inherit from io.IOBase (avoids __del__ overhead).""" + client = httpx.Client(transport=_EchoTransport()) + raw = client.get("http://test/ping") + resp = RESTResponse(raw) + self.assertNotIsInstance(resp, io.IOBase) + client.close() From 2c3e9fe762f5b08858302282cc16df7f60c6edd3 Mon Sep 17 00:00:00 2001 From: Viren Baraiya Date: Sat, 18 Apr 2026 16:17:12 -0700 Subject: [PATCH 4/4] fix(http): guard private httpx attr, fix async comment, fix test mocks MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Wrap resp._request = None in try/except AttributeError for forward compat with future httpx versions - Fix async_rest.py comment: BoundSyncStream → BoundAsyncStream - Fix test_task_runner_coverage.py mocks: use mock_http_resp.data instead of dead mock_http_resp.resp (AuthorizationException now reads .data, not .resp.text) Follow-up to #395 Co-Authored-By: Claude Opus 4.6 --- src/conductor/client/http/async_rest.py | 7 +++-- src/conductor/client/http/rest.py | 5 +++- .../automator/test_task_runner_coverage.py | 30 ++++++++----------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/src/conductor/client/http/async_rest.py 
b/src/conductor/client/http/async_rest.py index d6393c52f..ecf777c8c 100644 --- a/src/conductor/client/http/async_rest.py +++ b/src/conductor/client/http/async_rest.py @@ -13,9 +13,12 @@ def __init__(self, resp): self.reason = getattr(resp, 'reason_phrase', '') or self._get_reason_phrase(resp.status_code) self.data = resp.text # eagerly read body self.headers = resp.headers - # Break httpx Response <-> BoundSyncStream reference cycle (issue #395) + # Break httpx Response <-> BoundAsyncStream reference cycle (issue #395) resp.stream = None - resp._request = None + try: + resp._request = None + except AttributeError: + pass def _get_reason_phrase(self, status_code): """Get HTTP reason phrase from status code.""" diff --git a/src/conductor/client/http/rest.py b/src/conductor/client/http/rest.py index 8fbce2d1b..b7e3d0b17 100644 --- a/src/conductor/client/http/rest.py +++ b/src/conductor/client/http/rest.py @@ -15,7 +15,10 @@ def __init__(self, resp): self.headers = resp.headers # Break httpx Response <-> BoundSyncStream reference cycle (issue #395) resp.stream = None - resp._request = None + try: + resp._request = None + except AttributeError: + pass def _get_reason_phrase(self, status_code): """Get HTTP reason phrase from status code.""" diff --git a/tests/unit/automator/test_task_runner_coverage.py b/tests/unit/automator/test_task_runner_coverage.py index e2d474bf2..cce83bcf0 100644 --- a/tests/unit/automator/test_task_runner_coverage.py +++ b/tests/unit/automator/test_task_runner_coverage.py @@ -316,15 +316,13 @@ def test_poll_task_auth_failure_with_invalid_token(self, mock_sleep): task_runner = TaskRunner(worker=worker) # Create mock response with INVALID_TOKEN error - mock_resp = Mock() - mock_resp.text = '{"error": "INVALID_TOKEN"}' - mock_http_resp = Mock() - mock_http_resp.resp = mock_resp + mock_http_resp.data = '{"error": "INVALID_TOKEN"}' + mock_http_resp.status = 401 + mock_http_resp.reason = 'Unauthorized' + mock_http_resp.getheaders.return_value = {} 
auth_exception = AuthorizationException( - status=401, - reason='Unauthorized', http_resp=mock_http_resp ) @@ -342,15 +340,13 @@ def test_poll_task_auth_failure_without_invalid_token(self, mock_sleep): task_runner = TaskRunner(worker=worker) # Create mock response with different error code - mock_resp = Mock() - mock_resp.text = '{"error": "FORBIDDEN"}' - mock_http_resp = Mock() - mock_http_resp.resp = mock_resp + mock_http_resp.data = '{"error": "FORBIDDEN"}' + mock_http_resp.status = 403 + mock_http_resp.reason = 'Forbidden' + mock_http_resp.getheaders.return_value = {} auth_exception = AuthorizationException( - status=403, - reason='Forbidden', http_resp=mock_http_resp ) @@ -420,15 +416,13 @@ def test_poll_task_with_metrics_on_auth_error(self): ) # Create mock response with INVALID_TOKEN error - mock_resp = Mock() - mock_resp.text = '{"error": "INVALID_TOKEN"}' - mock_http_resp = Mock() - mock_http_resp.resp = mock_resp + mock_http_resp.data = '{"error": "INVALID_TOKEN"}' + mock_http_resp.status = 401 + mock_http_resp.reason = 'Unauthorized' + mock_http_resp.getheaders.return_value = {} auth_exception = AuthorizationException( - status=401, - reason='Unauthorized', http_resp=mock_http_resp )