diff --git a/CHANGELOG.md b/CHANGELOG.md index 53b4824..4dae1a6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,43 @@ All notable changes to ContextPilot will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.4.0] - 2026-03-29 + +### Added +- **Cloud prompt cache proxy** for Anthropic, OpenAI, and MiniMax — transparent prefix caching over cloud APIs +- **HTTP intercept proxy** — drop-in reverse proxy that extracts, reorders, and deduplicates documents in LLM requests without client changes +- **Block-level dedup** — content-defined chunking within tool results and assistant code blocks to deduplicate repeated content across turns +- **OpenClaw integration** — tool_result reordering, `markdown_header` extraction mode, deployment files, and quick-start guide +- **TTL-based cache eviction** policy with configurable tiers and automatic expiry +- **Conversation tracker** for multi-turn state: parent chain tracking, per-turn document history, and cross-turn block dedup +- `--chunk-modulus` CLI flag for tuning content-level dedup block size +- Cache sync documentation and `how_it_works.md` guide +- Pipeline diagram and architecture SVGs +- M5 MacBook Air results to Apple Silicon benchmark table +- P99 wall time to OpenClaw benchmark table + +### Changed +- Renamed dedup levels: file-level → document-level, block-level → content-level, content-level → ContextBlock-level +- Intercept parser supports multiple extraction formats (XML, numbered, separator, JSON results) with auto-detection +- Cloud adapters inject `cache_control` breakpoints on system prompts and tool results (limited to 4 per Anthropic API) +- Proxy forwards request metadata via headers instead of body to avoid breaking tool loops + +### Fixed +- Block dedup `"\n\n".join` corrupting content at chunk boundaries (phantom blank lines) +- `hash()` non-determinism in content-defined chunking — replaced with `hashlib.md5` +- `_chunk_modulus` missing from global declaration (CLI flag silently ignored) +- Proxy hardcoding `temperature=0` overwriting user values — now uses `setdefault` +- `default_ttl_seconds=0` silently becoming 300 (falsy `or` → `is not None`) +- `default_ttl` setter not syncing `_default_ttl_seconds` +- `update_from_response` double-counting partial cache hits +- Reconstruction functions using default config instead of original extraction config +- API key leak in error responses from `aiohttp.ClientError` +- Non-JSON upstream error crashing with `JSONDecodeError` +- Streaming connection leak on client disconnect (missing `finally` cleanup) +- Redundant `copy.deepcopy` doubling memory pressure per request +- Cycle detection added to `get_conversation_chain` +- Alpha header validation (non-numeric no longer crashes) + ## [0.3.5.post2] - 2026-03-05 ### Added diff --git a/README.md b/README.md index 69c8521..8115e47 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,8 @@ ## News +- [2026/03] Supports [OpenClaw](https://openclaw.ai) — [guide](docs/guides/openclaw.md) | [benchmark](docs/benchmarks/openclaw.md) +- [2026/03] Supports cloud APIs (OpenAI, Anthropic, MiniMax) — [cache sync](docs/guides/cache_sync.md) - [2026/03] ContextPilot now can run on **macOS / Apple Silicon** via [llama.cpp](docs/guides/mac_llama_cpp.md). - [2026/02] ContextPilot v0.3.2 released, supporting [PageIndex](https://github.com/VectifyAI/PageIndex) and [Mem0](https://github.com/mem0ai/mem0). - [2026/01] ContextPilot has been accepted to MLSys 2026 🎉! See you in Bellevue, WA, USA. @@ -28,7 +30,7 @@ Long-context workloads (RAG, memory chat, tool-augmented agents) prepend many co ContextPilot sits between context assembly and inference to maximize prefix reuse and remove duplicates: 1. **Higher throughput & cache hits** — boosts prefill throughput and prefix cache hit ratio via context reuse. -2. **Drop-in solutions** — works with [PageIndex](https://github.com/VectifyAI/PageIndex), [Mem0](https://github.com/mem0ai/mem0), [LMCache](https://github.com/LMCache/LMCache), and backends like [vLLM](https://github.com/vllm-project/vllm) / [SGLang](https://github.com/sgl-project/sglang) / [llama.cpp](docs/guides/mac_llama_cpp.md). +2. **Drop-in solutions** — supports [OpenClaw](https://openclaw.ai) ([guide](docs/guides/openclaw.md)), [PageIndex](https://github.com/VectifyAI/PageIndex), [Mem0](https://github.com/mem0ai/mem0), [LMCache](https://github.com/LMCache/LMCache), [vLLM](https://github.com/vllm-project/vllm), [SGLang](https://github.com/sgl-project/sglang), [llama.cpp](docs/guides/mac_llama_cpp.md), and cloud APIs (OpenAI, Anthropic). 3. **No compromise in reasoning quality** — can even improve with extremely long contexts. 4. **Widely tested** — validated across diverse RAG and agentic workloads. @@ -42,53 +44,63 @@ It maintains a **Context Index** of cached content, then per request applies **R ## Performance at a Glance -ContextPilot is validated across three representative settings: single-node academic RAG, multi-node production MoE inference, and multi-turn memory-augmented chat. In every case it delivers significant speedups with comparable answer quality. +**OpenClaw Agent on RTX 5090** — 60 enterprise document analysis tasks ([claw-tasks](https://github.com/EfficientContext/ClawTasks)), Qwen3-4B-Instruct via SGLang. [Full results →](docs/benchmarks/openclaw.md) -**Qwen3-32B on 4×A6000** — single-node academic RAG with a 32B model on consumer GPUs. - -| Benchmark | Method | Prefill TP (tok/s) | Cache Hit | F1 (%) | -|-----------|--------|--------------------|-----------|--------| -| MultihopRAG | SGLang | 7,290 | 4.64% | 60.42 | -| | **SGLang + ContextPilot** | **14,214** | **33.97%** | **64.39** | -| NarrativeQA | SGLang | 7,921 | 5.91% | 28.41 | -| | **SGLang + ContextPilot** | **12,117** | **20.82%** | **29.64** | - -**DeepSeek-R1-671B on 16×H20** — production-scale 671B MoE inference on a multi-node GPU cluster. - -| Benchmark | Method | Prefill TP (tok/s) | Cache Hit | F1 (%) | -|-----------|--------|--------------------|-----------|--------| -| MultihopRAG | SGLang | 9,636 | 5.12% | 64.15 | -| | **SGLang + ContextPilot** | **17,498** | **60.37%** | **64.68** | -| NarrativeQA | SGLang | 8,687 | 6.08% | 40.20 | -| | **SGLang + ContextPilot** | **13,201** | **38.24%** | **41.08** | +| Metric | OpenClaw + SGLang | + ContextPilot | Δ | +|--------|-------------------|----------------|---| +| Prompt tokens / request (avg) | 45,771 | 33,622 | **-26.5%** | +| Prompt tokens / request (P99) | 92,785 | 51,581 | **-44.4%** | +| Wall time (avg) | 26.1s | 20.8s | **-20.4%** | +| Wall time (P99) | 68.8s | 50.4s | **-26.6%** | +| Accuracy | 245/245 | 245/245 | ✓ | **Qwen3-4B on 1×A6000** — multi-turn memory chat with [Mem0](https://github.com/mem0ai/mem0) on the [LoCoMo](https://github.com/snap-research/locomo) benchmark. | Context Size | Method | TTFT (s) | LLM Judge | |--------------|--------|----------|-----------| +| 5 (long context memory) | SGLang | 0.1051 | 0.418 | +| | **SGLang + ContextPilot** | **0.0548** | 0.414 | | 100 memories | SGLang | 0.1012 | 0.437 | | | **SGLang + ContextPilot** | **0.0554** | 0.420 | >ContextPilot results in mem0 table are without context annotation — an optional feature that adds original importance ranking to reordered context blocks, which can further improve answer quality (see [Paper](https://arxiv.org/abs/2511.03475)). -**Llama-3.2-1B on Apple M3 (MacBook Air, 16 GB)** — MultihopRAG on Apple Silicon with llama.cpp, no GPU server required. +**Llama-3.2-1B on Apple Silicon** — MultihopRAG with llama.cpp, no GPU server required. -| Method | Avg Latency (ms) | -|--------|-----------------| -| llama.cpp | 3,315 | -| **llama.cpp + ContextPilot** | **1,378** | +| Device | Method | Avg Latency (ms) | +|--------|--------|-----------------| +| M3 (MacBook Air, 16 GB) | llama.cpp | 3,315 | +| | **llama.cpp + ContextPilot** | **1,378** | +| M5 (MacBook Air, 32 GB) | llama.cpp | 2,157 | +| | **llama.cpp + ContextPilot** | **911** | Settings: `Llama-3.2-1B-Instruct-Q4_K_M.gguf`, Metal offload (`-ngl 99`), `--cache-reuse 256`, `--parallel 4`, context 32768 tokens. See the [Mac + llama.cpp guide](docs/guides/mac_llama_cpp.md). +We also evaluated on academic RAG (Qwen3-32B, 4×A6000) and production MoE inference (DeepSeek-R1-671B, 16×H20) — see [RAG benchmarks](docs/benchmarks/rag.md) and [paper](https://arxiv.org/abs/2511.03475). + ## Installation **Requirements:** Python >= 3.10 --- -### vLLM / SGLang +### OpenClaw -ContextPilot works with both CPU and GPU backends for building the context index. The `[gpu]` extra enables GPU-accelerated distance computation (via `cupy-cuda12x`) and is faster for large batches; without it, ContextPilot falls back to the CPU backend automatically. +```bash +pip install contextpilot + +# Start proxy (points to your LLM backend) +python -m contextpilot.server.http_server \ + --port 8765 --infer-api-url http://localhost:30000 # SGLang + # or: --infer-api-url https://api.anthropic.com # Anthropic + # or: --infer-api-url https://api.openai.com # OpenAI +``` + +Then set OpenClaw's base URL to `http://localhost:8765/v1`. See the [full OpenClaw integration guide](docs/guides/openclaw.md) for UI setup, config file examples, and self-hosted model instructions. + +--- + +### vLLM / SGLang **From PyPI** — the vLLM and SGLang hooks are installed automatically: ```bash @@ -135,6 +147,19 @@ Docker images are also available for both all-in-one and standalone deployment. ## Getting Started +### Quick Start with OpenClaw + +```bash +# Ask OpenClaw to analyze vendor contracts (ContextPilot deduplicates shared content automatically) +openclaw agent --message "Read contracts/contract_alpha_cloud.txt and summarize the liability terms." +openclaw agent --message "Read contracts/contract_beta_ai.txt and compare its liability with Alpha." +openclaw agent --message "Read contracts/contract_gamma_security.txt. Rank all three by liability exposure." +``` + +When the agent reads multiple documents sharing content (contracts from the same template, proposals with shared methodology), ContextPilot automatically deduplicates identical blocks — reducing prefill tokens by ~27% with zero accuracy loss. See the [integration guide](docs/guides/openclaw.md) and [benchmark](docs/benchmarks/openclaw.md). + +--- + ### Quick Start with Context Ordering Add **one call** (`cp_instance.optimize()`) before inference to rearrange context blocks so that shared content aligns into a common prefix, enabling cache reuse. An importance ranking in the prompt preserves accuracy. diff --git a/contextpilot/__init__.py b/contextpilot/__init__.py index 831f6cd..4a6b312 100644 --- a/contextpilot/__init__.py +++ b/contextpilot/__init__.py @@ -7,13 +7,13 @@ Quick Start: >>> from contextpilot.pipeline import RAGPipeline - >>> + >>> >>> pipeline = RAGPipeline( ... retriever="bm25", ... corpus_path="corpus.jsonl", ... model="Qwen/Qwen2.5-7B-Instruct" ... ) - >>> + >>> >>> results = pipeline.run(queries=["What is AI?"]) See docs/reference/api.md for detailed documentation. @@ -38,6 +38,12 @@ from .server.live_index import ContextPilot +from .dedup import ( + dedup_chat_completions, + dedup_responses_api, + DedupResult, +) + from .api import optimize, optimize_batch from .retriever import ( @@ -53,27 +59,28 @@ __all__ = [ # High-level pipeline API - 'RAGPipeline', - 'RetrieverConfig', - 'OptimizerConfig', - 'InferenceConfig', - 'PipelineConfig', - + "RAGPipeline", + "RetrieverConfig", + "OptimizerConfig", + "InferenceConfig", + "PipelineConfig", # Core components - 'ContextIndex', - 'IndexResult', - 'IntraContextOrderer', - 'ContextPilot', - + "ContextIndex", + "IndexResult", + "IntraContextOrderer", + "ContextPilot", + # Deduplication + "dedup_chat_completions", + "dedup_responses_api", + "DedupResult", # Convenience functions - 'optimize', - 'optimize_batch', - + "optimize", + "optimize_batch", # Retrievers - 'BM25Retriever', - 'FAISSRetriever', - 'FAISS_AVAILABLE', - 'Mem0Retriever', - 'create_mem0_corpus_map', - 'MEM0_AVAILABLE', + "BM25Retriever", + "FAISSRetriever", + "FAISS_AVAILABLE", + "Mem0Retriever", + "create_mem0_corpus_map", + "MEM0_AVAILABLE", ] diff --git a/contextpilot/dedup/__init__.py b/contextpilot/dedup/__init__.py new file mode 100644 index 0000000..77d3462 --- /dev/null +++ b/contextpilot/dedup/__init__.py @@ -0,0 +1,11 @@ +from .block_dedup import ( + dedup_chat_completions, + dedup_responses_api, + DedupResult, +) + +__all__ = [ + "dedup_chat_completions", + "dedup_responses_api", + "DedupResult", +] diff --git a/contextpilot/dedup/block_dedup.py b/contextpilot/dedup/block_dedup.py new file mode 100644 index 0000000..5d7c6b7 --- /dev/null +++ b/contextpilot/dedup/block_dedup.py @@ -0,0 +1,312 @@ +import hashlib +import logging +import re +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple + +logger = logging.getLogger(__name__) + +MIN_BLOCK_CHARS = 80 +MIN_CONTENT_CHARS = 500 + +CHUNK_MODULUS = 13 +CHUNK_MIN_LINES = 5 +CHUNK_MAX_LINES = 40 + + +@dataclass +class DedupResult: + blocks_deduped: int = 0 + blocks_total: int = 0 + chars_before: int = 0 + chars_after: int = 0 + chars_saved: int = 0 + + +def _build_tool_name_map_openai(messages: list) -> Dict[str, str]: + mapping: Dict[str, str] = {} + for msg in messages: + if not isinstance(msg, dict) or msg.get("role") != "assistant": + continue + for tc in msg.get("tool_calls", []): + if not isinstance(tc, dict): + continue + tc_id = tc.get("id", "") + fn = tc.get("function", {}) + if isinstance(fn, dict) and fn.get("name"): + mapping[tc_id] = fn["name"] + return mapping + + +def _build_tool_name_map_responses(items: list) -> Dict[str, str]: + mapping: Dict[str, str] = {} + for item in items: + if isinstance(item, dict) and item.get("type") == "function_call": + cid = item.get("call_id", "") + name = item.get("name", "") + if cid and name: + mapping[cid] = name + return mapping + + +def _content_defined_chunking( + text: str, chunk_modulus: int = CHUNK_MODULUS +) -> List[str]: + lines = text.split("\n") + if len(lines) <= CHUNK_MIN_LINES: + return [text] + + blocks: List[str] = [] + current: List[str] = [] + + for line in lines: + current.append(line) + line_hash = int.from_bytes( + hashlib.md5(line.strip().encode("utf-8", errors="replace")).digest()[:4], + "little", + ) + is_boundary = ( + line_hash % chunk_modulus == 0 and len(current) >= CHUNK_MIN_LINES + ) or len(current) >= CHUNK_MAX_LINES + if is_boundary: + blocks.append("\n".join(current)) + current = [] + + if current: + if blocks and len(current) < CHUNK_MIN_LINES: + blocks[-1] += "\n" + "\n".join(current) + else: + blocks.append("\n".join(current)) + + return blocks + + +def _hash_block(block: str) -> str: + normalized = block.strip() + return hashlib.sha256(normalized.encode("utf-8", errors="replace")).hexdigest()[:20] + + +def _dedup_text( + text: str, + seen_blocks: Dict[str, Tuple[int, str, int]], + msg_idx: int, + fn_name: str, + result: DedupResult, + min_block_chars: int, + chunk_modulus: int, +) -> Optional[str]: + """Core dedup loop shared by all entry points. + + Returns the deduped text if any blocks were deduped, or None otherwise. + """ + blocks = _content_defined_chunking(text, chunk_modulus) + if len(blocks) < 2: + for b in blocks: + if len(b.strip()) >= min_block_chars: + h = _hash_block(b) + result.blocks_total += 1 + if h not in seen_blocks: + seen_blocks[h] = (msg_idx, fn_name, 0) + return None + + new_blocks = [] + deduped_count = 0 + + for block_idx, block in enumerate(blocks): + if len(block.strip()) < min_block_chars: + new_blocks.append(block) + continue + + h = _hash_block(block) + result.blocks_total += 1 + + if h in seen_blocks and seen_blocks[h][0] != msg_idx: + _, orig_fn, _ = seen_blocks[h] + first_line = block.strip().split("\n")[0][:80] + ref = f'[... "{first_line}" — identical to earlier {orig_fn} result, see above ...]' + chars_saved = len(block) - len(ref) + if chars_saved > 0: + new_blocks.append(ref) + deduped_count += 1 + result.blocks_deduped += 1 + else: + new_blocks.append(block) + else: + if h not in seen_blocks: + seen_blocks[h] = (msg_idx, fn_name, block_idx) + new_blocks.append(block) + + if deduped_count > 0: + return "\n".join(new_blocks) + + # Nothing deduped — register all blocks for future lookups + for block_idx, block in enumerate(blocks): + if len(block.strip()) >= min_block_chars: + h = _hash_block(block) + if h not in seen_blocks: + seen_blocks[h] = (msg_idx, fn_name, block_idx) + return None + + +def dedup_chat_completions( + body: dict, + min_block_chars: int = MIN_BLOCK_CHARS, + min_content_chars: int = MIN_CONTENT_CHARS, + chunk_modulus: int = CHUNK_MODULUS, +) -> DedupResult: + messages = body.get("messages") + if not isinstance(messages, list) or not messages: + return DedupResult() + + tool_names = _build_tool_name_map_openai(messages) + seen_blocks: Dict[str, Tuple[int, str, int]] = {} + result = DedupResult() + + for idx, msg in enumerate(messages): + if not isinstance(msg, dict) or msg.get("role") != "tool": + continue + + content = msg.get("content", "") + if not isinstance(content, str) or len(content) < min_content_chars: + continue + + tc_id = msg.get("tool_call_id", "") + fn_name = tool_names.get(tc_id, msg.get("name", "")) or "tool" + + new_content = _dedup_text( + content, seen_blocks, idx, fn_name, result, + min_block_chars, chunk_modulus, + ) + if new_content is not None: + original_len = len(content) + msg["content"] = new_content + new_len = len(new_content) + result.chars_before += original_len + result.chars_after += new_len + result.chars_saved += original_len - new_len + logger.info( + f"Block dedup: msg[{idx}] {fn_name} — " + f"saved {original_len - new_len:,} chars" + ) + + _dedup_assistant_code_blocks( + messages, seen_blocks, result, min_block_chars, min_content_chars, chunk_modulus + ) + + return result + + +_CODE_BLOCK_RE = re.compile(r"(```[\w]*\n)(.*?)(```)", re.DOTALL) + + +def _dedup_assistant_code_blocks( + messages: list, + seen_blocks: Dict[str, Tuple[int, str, int]], + result: DedupResult, + min_block_chars: int, + min_content_chars: int, + chunk_modulus: int, +) -> None: + for idx, msg in enumerate(messages): + if not isinstance(msg, dict) or msg.get("role") != "assistant": + continue + raw_content = msg.get("content", "") + + # Handle both string and list (content blocks) formats + is_list_content = False + text_block_idx = -1 + if isinstance(raw_content, str): + content = raw_content + elif isinstance(raw_content, list): + # OpenClaw sends [{type: "text", text: "..."}, ...] + # Find the text block that contains code + content = "" + for bi, block in enumerate(raw_content): + if isinstance(block, dict) and block.get("type") == "text": + t = block.get("text", "") + if "```" in t and len(t) > len(content): + content = t + text_block_idx = bi + is_list_content = True + if not content: + continue + else: + continue + + if len(content) < min_content_chars: + continue + + code_blocks = list(_CODE_BLOCK_RE.finditer(content)) + if not code_blocks: + continue + + modified = False + new_content = content + + for match in reversed(code_blocks): + code = match.group(2) + if len(code.strip()) < min_block_chars: + continue + + new_code = _dedup_text( + code, seen_blocks, idx, "assistant", result, + min_block_chars, chunk_modulus, + ) + if new_code is not None: + start, end = match.start(2), match.end(2) + original_len = end - start + new_content = new_content[:start] + new_code + new_content[end:] + result.chars_before += original_len + result.chars_after += len(new_code) + result.chars_saved += original_len - len(new_code) + modified = True + + if modified: + if is_list_content and text_block_idx >= 0: + msg["content"][text_block_idx]["text"] = new_content + else: + msg["content"] = new_content + + +def dedup_responses_api( + body: dict, + min_block_chars: int = MIN_BLOCK_CHARS, + min_content_chars: int = MIN_CONTENT_CHARS, + chunk_modulus: int = CHUNK_MODULUS, +) -> DedupResult: + input_items = body.get("input") + if not isinstance(input_items, list) or not input_items: + return DedupResult() + + fn_names = _build_tool_name_map_responses(input_items) + seen_blocks: Dict[str, Tuple[int, str, int]] = {} + result = DedupResult() + + for idx, item in enumerate(input_items): + if not isinstance(item, dict) or item.get("type") != "function_call_output": + continue + + output = item.get("output", "") + if not isinstance(output, str) or len(output) < min_content_chars: + continue + + call_id = item.get("call_id", "") + fn_name = fn_names.get(call_id, call_id) or "tool" + + new_output = _dedup_text( + output, seen_blocks, idx, fn_name, result, + min_block_chars, chunk_modulus, + ) + if new_output is not None: + original_len = len(output) + item["output"] = new_output + new_len = len(new_output) + result.chars_before += original_len + result.chars_after += new_len + result.chars_saved += original_len - new_len + logger.info( + f"Block dedup: input[{idx}] {fn_name} — " + f"saved {original_len - new_len:,} chars" + ) + + return result diff --git a/contextpilot/server/__init__.py b/contextpilot/server/__init__.py index e36633e..34146d5 100644 --- a/contextpilot/server/__init__.py +++ b/contextpilot/server/__init__.py @@ -8,12 +8,11 @@ """ from .metadata import NodeMetadata -from .eviction_heap import EvictionHeap from .live_index import ContextPilot -# HTTP client (optional - requires requests) try: from .http_client import ContextPilotIndexClient, evict_tokens + _HTTP_AVAILABLE = True except ImportError: _HTTP_AVAILABLE = False @@ -21,9 +20,8 @@ evict_tokens = None __all__ = [ - 'NodeMetadata', - 'EvictionHeap', - 'ContextPilot', - 'ContextPilotIndexClient', - 'evict_tokens', + "NodeMetadata", + "ContextPilot", + "ContextPilotIndexClient", + "evict_tokens", ] diff --git a/contextpilot/server/cloud_adapters/__init__.py b/contextpilot/server/cloud_adapters/__init__.py new file mode 100644 index 0000000..1c777a6 --- /dev/null +++ b/contextpilot/server/cloud_adapters/__init__.py @@ -0,0 +1,50 @@ +""" +Cloud Provider Adapters for ContextPilot Prompt Cache Proxy. + +Provides adapters for Anthropic, OpenAI, and MiniMax cloud LLM APIs, +handling API-specific auth, cache control injection, and response parsing. +""" + +from .base import CloudProviderAdapter, CacheMetrics, TTLTier +from .anthropic_adapter import AnthropicAdapter +from .openai_adapter import OpenAIAdapter +from .minimax_adapter import MiniMaxAdapter + + +_ADAPTERS = { + "anthropic": AnthropicAdapter, + "openai": OpenAIAdapter, + "minimax": MiniMaxAdapter, +} + + +def get_cloud_adapter(provider: str) -> CloudProviderAdapter: + """Factory: create adapter by provider name. + + Args: + provider: One of 'anthropic', 'openai', 'minimax' + + Returns: + CloudProviderAdapter instance + + Raises: + ValueError: If provider is not recognized + """ + cls = _ADAPTERS.get(provider) + if cls is None: + raise ValueError( + f"Unknown cloud provider: {provider!r}. " + f"Choose from: {list(_ADAPTERS.keys())}" + ) + return cls() + + +__all__ = [ + "CloudProviderAdapter", + "CacheMetrics", + "TTLTier", + "AnthropicAdapter", + "OpenAIAdapter", + "MiniMaxAdapter", + "get_cloud_adapter", +] diff --git a/contextpilot/server/cloud_adapters/anthropic_adapter.py b/contextpilot/server/cloud_adapters/anthropic_adapter.py new file mode 100644 index 0000000..da5722a --- /dev/null +++ b/contextpilot/server/cloud_adapters/anthropic_adapter.py @@ -0,0 +1,151 @@ +""" +Anthropic Cloud Provider Adapter. + +Handles Anthropic Messages API specifics: +- cache_control: {"type": "ephemeral"} injection on content blocks +- x-api-key authentication +- Cache metrics parsing from response.usage +""" + +import copy +import logging +from typing import Any, Dict, FrozenSet, List, Set + +from .base import CacheMetrics, CloudProviderAdapter, TTLTier + +logger = logging.getLogger(__name__) + +_ANTHROPIC_API_BASE = "https://api.anthropic.com" +_ANTHROPIC_VERSION = "2023-06-01" +_MIN_CONTENT_LENGTH_FOR_CACHE = 1024 + +_CACHE_CONTROL_DEFAULT = {"type": "ephemeral"} +_CACHE_CONTROL_EXTENDED = {"type": "ephemeral", "ttl": "1h"} + + +class AnthropicAdapter(CloudProviderAdapter): + """Adapter for Anthropic Messages API with prompt caching support.""" + + @property + def provider_name(self) -> str: + return "anthropic" + + def get_api_url(self, path: str = "") -> str: + return f"{_ANTHROPIC_API_BASE}{path}" + + def get_auth_headers(self, api_key: str) -> Dict[str, str]: + return { + "x-api-key": api_key, + "anthropic-version": _ANTHROPIC_VERSION, + "content-type": "application/json", + } + + def get_target_path(self) -> str: + return "/v1/messages" + + def get_default_ttl_seconds(self) -> int: + return 300 + + def get_extended_ttl_seconds(self): + return 3600 + + @property + def _cache_control_value(self) -> Dict[str, str]: + if self._configured_ttl == TTLTier.LONG: + return _CACHE_CONTROL_EXTENDED + return _CACHE_CONTROL_DEFAULT + + def inject_cache_control( + self, body: Dict[str, Any], cached_hashes: Set[str] + ) -> Dict[str, Any]: + body = copy.deepcopy(body) + cc = self._cache_control_value + body = _inject_system_cache_control(body, cc) + body = _inject_tool_result_cache_control(body, cc) + return body + + def parse_cache_metrics(self, response_body: Dict[str, Any]) -> CacheMetrics: + usage = response_body.get("usage", {}) + return CacheMetrics( + cache_creation_tokens=usage.get("cache_creation_input_tokens", 0), + cache_read_tokens=usage.get("cache_read_input_tokens", 0), + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + ) + + +# --------------------------------------------------------------------------- +# Helpers (shared with MiniMaxAdapter via import) +# --------------------------------------------------------------------------- + + +def _inject_system_cache_control( + body: Dict[str, Any], cc: Dict[str, str] +) -> Dict[str, Any]: + system = body.get("system") + if system is None: + return body + + if isinstance(system, str): + body["system"] = [{"type": "text", "text": system, "cache_control": cc}] + elif isinstance(system, list) and system: + last_block = system[-1] + if isinstance(last_block, dict): + last_block["cache_control"] = cc + return body + + +_MAX_TOOL_RESULT_BREAKPOINTS = 3 # Anthropic allows 4 total; 1 reserved for system + + +def _inject_tool_result_cache_control( + body: Dict[str, Any], cc: Dict[str, str] +) -> Dict[str, Any]: + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return body + + breakpoints_used = 0 + for msg in messages: + if breakpoints_used >= _MAX_TOOL_RESULT_BREAKPOINTS: + break + if msg.get("role") != "user": + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for block in content: + if breakpoints_used >= _MAX_TOOL_RESULT_BREAKPOINTS: + break + if not isinstance(block, dict): + continue + if block.get("type") not in ("tool_result", "toolResult"): + continue + _maybe_add_cache_control_to_tool_result(block, cc) + breakpoints_used += 1 + + return body + + +def _maybe_add_cache_control_to_tool_result( + block: Dict[str, Any], cc: Dict[str, str] +) -> None: + tr_content = block.get("content", "") + + if isinstance(tr_content, str): + if len(tr_content) >= _MIN_CONTENT_LENGTH_FOR_CACHE: + block["cache_control"] = cc + elif isinstance(tr_content, list): + total_chars = sum( + len(inner.get("text", "")) + for inner in tr_content + if isinstance(inner, dict) and inner.get("type") == "text" + ) + if total_chars >= _MIN_CONTENT_LENGTH_FOR_CACHE and tr_content: + last_text_block = None + for inner in reversed(tr_content): + if isinstance(inner, dict) and inner.get("type") == "text": + last_text_block = inner + break + if last_text_block is not None: + last_text_block["cache_control"] = cc diff --git a/contextpilot/server/cloud_adapters/base.py b/contextpilot/server/cloud_adapters/base.py new file mode 100644 index 0000000..2a0b12c --- /dev/null +++ b/contextpilot/server/cloud_adapters/base.py @@ -0,0 +1,88 @@ +""" +Base classes and shared types for cloud provider adapters. +""" + +from abc import ABC, abstractmethod +from typing import Dict, Any, FrozenSet, Optional, Set + +from contextpilot.server.ttl_eviction import TTLTier, CacheMetrics + + +class CloudProviderAdapter(ABC): + """Abstract base for cloud LLM API provider adapters. + + Each adapter handles provider-specific details: + - API URL construction + - Authentication headers + - Cache control annotation injection + - Response cache metrics parsing + """ + + def __init__(self): + self._configured_ttl: Optional[TTLTier] = None + + @property + def configured_ttl(self) -> Optional[TTLTier]: + return self._configured_ttl + + @configured_ttl.setter + def configured_ttl(self, value: TTLTier): + self._configured_ttl = value + + @property + @abstractmethod + def provider_name(self) -> str: + """Unique provider identifier (e.g. 'anthropic', 'openai').""" + ... + + @abstractmethod + def get_api_url(self, path: str = "") -> str: + """Build full API URL for the given path.""" + ... + + @abstractmethod + def get_auth_headers(self, api_key: str) -> Dict[str, str]: + """Build authentication headers for the provider.""" + ... + + @abstractmethod + def inject_cache_control( + self, body: Dict[str, Any], cached_hashes: Set[str] + ) -> Dict[str, Any]: + """Add provider-specific cache control annotations to the request body. + + Args: + body: Request body (will be deep-copied internally if modified) + cached_hashes: Set of content hashes currently in cache + + Returns: + Modified request body with cache control annotations + """ + ... + + @abstractmethod + def parse_cache_metrics(self, response_body: Dict[str, Any]) -> CacheMetrics: + """Extract cache usage metrics from the API response.""" + ... + + @abstractmethod + def get_default_ttl_seconds(self) -> int: + """Default local index TTL in seconds.""" + ... + + @abstractmethod + def get_extended_ttl_seconds(self) -> Optional[int]: + """Extended TTL in seconds, or None if not supported.""" + ... + + @property + def supports_extended_cache(self) -> bool: + return self.get_extended_ttl_seconds() is not None + + @abstractmethod + def get_target_path(self) -> str: + """Get the API endpoint path (e.g. '/v1/messages').""" + ... + + def __repr__(self): + return f"{self.__class__.__name__}(provider={self.provider_name!r})" diff --git a/contextpilot/server/cloud_adapters/minimax_adapter.py b/contextpilot/server/cloud_adapters/minimax_adapter.py new file mode 100644 index 0000000..03823b4 --- /dev/null +++ b/contextpilot/server/cloud_adapters/minimax_adapter.py @@ -0,0 +1,67 @@ +""" +MiniMax Cloud Provider Adapter. + +MiniMax provides an Anthropic-compatible API at api.minimax.io/anthropic. +Uses the same cache_control: {"type": "ephemeral"} format and response +metrics as Anthropic. +""" + +import copy +import logging +from typing import Any, Dict, FrozenSet, Set + +from .base import CacheMetrics, CloudProviderAdapter, TTLTier +from .anthropic_adapter import ( + _CACHE_CONTROL_DEFAULT, + _inject_system_cache_control, + _inject_tool_result_cache_control, +) + +logger = logging.getLogger(__name__) + +_MINIMAX_API_BASE = "https://api.minimax.io/anthropic" + + +class MiniMaxAdapter(CloudProviderAdapter): + """Adapter for MiniMax Anthropic-compatible API with prompt caching.""" + + @property + def provider_name(self) -> str: + return "minimax" + + def get_api_url(self, path: str = "") -> str: + return f"{_MINIMAX_API_BASE}{path}" + + def get_auth_headers(self, api_key: str) -> Dict[str, str]: + return { + "x-api-key": api_key, + "content-type": "application/json", + } + + def get_target_path(self) -> str: + return "/v1/messages" + + def get_default_ttl_seconds(self) -> int: + return 300 + + def get_extended_ttl_seconds(self): + return None + + def inject_cache_control( + self, body: Dict[str, Any], cached_hashes: Set[str] + ) -> Dict[str, Any]: + """Inject cache_control using Anthropic-compatible format.""" + body = copy.deepcopy(body) + body = _inject_system_cache_control(body, _CACHE_CONTROL_DEFAULT) + body = _inject_tool_result_cache_control(body, _CACHE_CONTROL_DEFAULT) + return body + + def parse_cache_metrics(self, response_body: Dict[str, Any]) -> CacheMetrics: + """Parse cache metrics — same format as Anthropic.""" + usage = response_body.get("usage", {}) + return CacheMetrics( + cache_creation_tokens=usage.get("cache_creation_input_tokens", 0), + cache_read_tokens=usage.get("cache_read_input_tokens", 0), + input_tokens=usage.get("input_tokens", 0), + output_tokens=usage.get("output_tokens", 0), + ) diff --git a/contextpilot/server/cloud_adapters/openai_adapter.py b/contextpilot/server/cloud_adapters/openai_adapter.py new file mode 100644 index 0000000..c93e001 --- /dev/null +++ b/contextpilot/server/cloud_adapters/openai_adapter.py @@ -0,0 +1,69 @@ +""" +OpenAI Cloud Provider Adapter. + +Only supports explicit prompt_cache_retention="24h" (extended caching). +In-memory caching (5-10 min, auto-adjusted) is NOT supported because the +TTL is non-deterministic. +""" + +import copy +import logging +from typing import Any, Dict, FrozenSet, Set + +from .base import CacheMetrics, CloudProviderAdapter, TTLTier + +logger = logging.getLogger(__name__) + +_OPENAI_API_BASE = "https://api.openai.com" + + +class OpenAIAdapter(CloudProviderAdapter): + @property + def provider_name(self) -> str: + return "openai" + + def get_api_url(self, path: str = "") -> str: + return f"{_OPENAI_API_BASE}{path}" + + def get_auth_headers(self, api_key: str) -> Dict[str, str]: + return { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + } + + def get_target_path(self) -> str: + return "/v1/chat/completions" + + def get_default_ttl_seconds(self) -> int: + return 3600 + + def get_extended_ttl_seconds(self): + return 86400 + + def inject_cache_control( + self, body: Dict[str, Any], cached_hashes: Set[str] + ) -> Dict[str, Any]: + if self.configured_ttl == TTLTier.LONG: + body = copy.deepcopy(body) + body["prompt_cache_retention"] = "24h" + return body + + def parse_cache_metrics(self, response_body: Dict[str, Any]) -> CacheMetrics: + usage = response_body.get("usage", {}) + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) + + # OpenAI reports cached tokens in prompt_tokens_details + details = usage.get("prompt_tokens_details", {}) + cached_tokens = 0 + if isinstance(details, dict): + cached_tokens = details.get("cached_tokens", 0) + + return CacheMetrics( + cache_creation_tokens=max(0, prompt_tokens - cached_tokens) + if cached_tokens + else 0, + cache_read_tokens=cached_tokens, + input_tokens=prompt_tokens, + output_tokens=completion_tokens, + ) diff --git a/contextpilot/server/conversation_tracker.py b/contextpilot/server/conversation_tracker.py index 00ca9aa..4befb39 100644 --- a/contextpilot/server/conversation_tracker.py +++ b/contextpilot/server/conversation_tracker.py @@ -6,10 +6,10 @@ Usage: tracker = ConversationTracker() - + # Turn 1 req_a_id = tracker.register_request(docs=[4, 3, 1]) - + # Turn 2 (continuation of Turn 1) result = tracker.deduplicate( request_id=req_b_id, @@ -31,33 +31,22 @@ @dataclass class DeduplicationResult: - """Result of context deduplication.""" - - # Original documents in this turn original_docs: List[int] - - # Documents that overlap with previous turns (to be deduplicated) overlapping_docs: List[int] - - # New documents not seen in previous turns new_docs: List[int] - - # Reference hints for overlapping documents reference_hints: List[str] - - # The deduplicated context (new docs only, preserving order) deduplicated_docs: List[int] - - # Map: doc_id -> which turn it first appeared in doc_source_turns: Dict[int, str] = field(default_factory=dict) - - # Whether this is a new conversation (no parent, Turn 1) is_new_conversation: bool = False + blocks_deduped: int = 0 + blocks_total: int = 0 + block_chars_saved: int = 0 @dataclass class RequestHistory: """History of a single request.""" + request_id: str docs: List[int] parent_request_id: Optional[str] = None @@ -68,47 +57,49 @@ class RequestHistory: class ConversationTracker: """ Tracks conversation history for multi-turn context deduplication. - + Features: - Track documents sent per request - Track conversation chains (parent-child relationships) - Deduplicate contexts by removing already-seen documents - Generate reference hints for deduplicated documents """ - + def __init__(self, hint_template: str = None): """ Initialize the tracker. - + Args: hint_template: Template for reference hints. Default: "Please refer to [Doc {doc_id}] from turn {turn_number}." """ # request_id -> RequestHistory self._requests: Dict[str, RequestHistory] = {} - + # Template for generating reference hints - self._hint_template = hint_template or "Please refer to [Doc {doc_id}] from the previous conversation turn." - + self._hint_template = ( + hint_template + or "Please refer to [Doc {doc_id}] from the previous conversation turn." + ) + # Statistics self._stats = { - 'total_requests': 0, - 'total_dedup_calls': 0, - 'total_docs_deduplicated': 0, + "total_requests": 0, + "total_dedup_calls": 0, + "total_docs_deduplicated": 0, } - - def register_request(self, - request_id: str, - docs: List[int], - parent_request_id: Optional[str] = None) -> RequestHistory: + + def register_request( + self, request_id: str, docs: List[int], parent_request_id: Optional[str] = None + ) -> RequestHistory: """ Register a request and its documents. - + Args: request_id: Unique identifier for this request docs: List of document IDs sent in this request parent_request_id: ID of the previous turn's request (if multi-turn) - + Returns: RequestHistory object """ @@ -116,187 +107,250 @@ def register_request(self, turn_number = 1 if parent_request_id and parent_request_id in self._requests: turn_number = self._requests[parent_request_id].turn_number + 1 - + history = RequestHistory( request_id=request_id, docs=list(docs), parent_request_id=parent_request_id, - turn_number=turn_number + turn_number=turn_number, ) - + self._requests[request_id] = history - self._stats['total_requests'] += 1 - - logger.debug(f"Registered request {request_id}: {len(docs)} docs, turn {turn_number}") - + self._stats["total_requests"] += 1 + + logger.debug( + f"Registered request {request_id}: {len(docs)} docs, turn {turn_number}" + ) + return history - + def get_conversation_chain(self, request_id: str) -> List[RequestHistory]: """ Get the full conversation chain leading to this request. - + Args: request_id: The current request ID - + Returns: List of RequestHistory objects from first turn to current, in order """ chain = [] current_id = request_id - + visited = set() + while current_id and current_id in self._requests: + if current_id in visited: + logger.warning(f"Cycle detected in conversation chain at {current_id}") + break + visited.add(current_id) chain.append(self._requests[current_id]) current_id = self._requests[current_id].parent_request_id - + # Reverse to get chronological order chain.reverse() return chain - - def get_all_previous_docs(self, parent_request_id: str) -> Tuple[Set[int], Dict[int, str]]: + + def get_all_previous_docs( + self, parent_request_id: str + ) -> Tuple[Set[int], Dict[int, str]]: """ Get all documents from previous turns in the conversation. - + Args: parent_request_id: The parent request ID - + Returns: Tuple of (set of all doc IDs, dict mapping doc_id to request_id where it first appeared) """ all_docs = set() doc_sources = {} # doc_id -> request_id where it first appeared - + chain = self.get_conversation_chain(parent_request_id) - + for history in chain: for doc_id in history.docs: if doc_id not in all_docs: all_docs.add(doc_id) doc_sources[doc_id] = history.request_id - + return all_docs, doc_sources - - def deduplicate(self, - request_id: str, - docs: List[int], - parent_request_id: Optional[str] = None, - hint_template: Optional[str] = None) -> DeduplicationResult: - """ - Deduplicate documents for a new turn based on conversation history. - - Args: - request_id: ID for this request - docs: Documents retrieved for this turn - parent_request_id: ID of the previous turn's request - hint_template: Optional custom template for hints - - Returns: - DeduplicationResult with new docs, overlapping docs, and reference hints - """ - self._stats['total_dedup_calls'] += 1 - - # If no parent, this is turn 1 - no deduplication needed + + def deduplicate( + self, + request_id: str, + docs: Optional[List[int]] = None, + parent_request_id: Optional[str] = None, + hint_template: Optional[str] = None, + doc_contents: Optional[Dict[int, str]] = None, + ) -> DeduplicationResult: + if docs is None and doc_contents is not None: + docs = list(doc_contents.keys()) + elif docs is None: + docs = [] + self._stats["total_dedup_calls"] += 1 + if not parent_request_id or parent_request_id not in self._requests: - # Register this request self.register_request(request_id, docs, parent_request_id=None) - - return DeduplicationResult( + result = DeduplicationResult( original_docs=docs, overlapping_docs=[], new_docs=docs, reference_hints=[], deduplicated_docs=docs, doc_source_turns={}, - is_new_conversation=True + is_new_conversation=True, ) - - # Get all docs from previous turns + if doc_contents: + self._apply_block_dedup(doc_contents, result) + return result + previous_docs, doc_sources = self.get_all_previous_docs(parent_request_id) - - # Separate into overlapping and new + overlapping_docs = [] new_docs = [] doc_source_turns = {} - + for doc_id in docs: if doc_id in previous_docs: overlapping_docs.append(doc_id) doc_source_turns[doc_id] = doc_sources[doc_id] else: new_docs.append(doc_id) - - # Generate reference hints for overlapping docs + template = hint_template or self._hint_template reference_hints = [] - + for doc_id in overlapping_docs: source_request = doc_sources.get(doc_id) - source_history = self._requests.get(source_request) if source_request else None + source_history = ( + self._requests.get(source_request) if source_request else None + ) turn_number = source_history.turn_number if source_history else "previous" - hint = template.format( doc_id=doc_id, turn_number=turn_number, - source_request=source_request or "previous" + source_request=source_request or "previous", ) reference_hints.append(hint) - - # Register this request with only the new docs (for future deduplication) - # But store original docs for complete history + self.register_request(request_id, docs, parent_request_id) - - # Update stats - self._stats['total_docs_deduplicated'] += len(overlapping_docs) - - logger.info(f"Deduplication for {request_id}: " - f"{len(overlapping_docs)} overlapping, {len(new_docs)} new") - - return DeduplicationResult( + self._stats["total_docs_deduplicated"] += len(overlapping_docs) + + logger.info( + f"Deduplication for {request_id}: " + f"{len(overlapping_docs)} overlapping, {len(new_docs)} new" + ) + + result = DeduplicationResult( original_docs=docs, overlapping_docs=overlapping_docs, new_docs=new_docs, reference_hints=reference_hints, deduplicated_docs=new_docs, doc_source_turns=doc_source_turns, - is_new_conversation=False + is_new_conversation=False, ) - - def deduplicate_batch(self, - request_ids: List[str], - docs_list: List[List[int]], - parent_request_ids: Optional[List[Optional[str]]] = None, - hint_template: Optional[str] = None) -> List[DeduplicationResult]: - """ - Deduplicate multiple requests at once. - - Args: - request_ids: List of request IDs - docs_list: List of document lists, one per request - parent_request_ids: List of parent request IDs (None for turn 1) - hint_template: Optional custom template for hints - - Returns: - List of DeduplicationResult objects - """ + + if doc_contents: + self._apply_block_dedup(doc_contents, result) + + return result + + def _apply_block_dedup( + self, doc_contents: Dict[int, str], result: DeduplicationResult + ) -> None: + from contextpilot.dedup.block_dedup import ( + _content_defined_chunking, + _hash_block, + MIN_BLOCK_CHARS, + MIN_CONTENT_CHARS, + ) + + seen_blocks: Dict[str, int] = {} + + for doc_id in result.original_docs: + content = doc_contents.get(doc_id, "") + if len(content) < MIN_CONTENT_CHARS: + continue + + blocks = _content_defined_chunking(content) + if len(blocks) < 2: + for b in blocks: + if len(b.strip()) >= MIN_BLOCK_CHARS: + h = _hash_block(b) + if h not in seen_blocks: + seen_blocks[h] = doc_id + continue + + new_blocks = [] + deduped_count = 0 + + for block in blocks: + if len(block.strip()) < MIN_BLOCK_CHARS: + new_blocks.append(block) + continue + + h = _hash_block(block) + result.blocks_total += 1 + + if h in seen_blocks and seen_blocks[h] != doc_id: + first_line = block.strip().split("\n")[0][:80] + ref = f'[... "{first_line}" — identical to earlier result, see above ...]' + if len(block) > len(ref): + new_blocks.append(ref) + deduped_count += 1 + result.blocks_deduped += 1 + result.block_chars_saved += len(block) - len(ref) + else: + new_blocks.append(block) + else: + if h not in seen_blocks: + seen_blocks[h] = doc_id + new_blocks.append(block) + + if deduped_count > 0: + doc_contents[doc_id] = "\n".join(new_blocks) + + if result.blocks_deduped > 0: + logger.info( + f"Block dedup: {result.blocks_deduped}/{result.blocks_total} blocks, " + f"saved {result.block_chars_saved:,} chars" + ) + + def deduplicate_batch( + self, + request_ids: List[str], + docs_list: List[List[int]], + parent_request_ids: Optional[List[Optional[str]]] = None, + hint_template: Optional[str] = None, + doc_contents_list: Optional[List[Optional[Dict[int, str]]]] = None, + ) -> List[DeduplicationResult]: if parent_request_ids is None: parent_request_ids = [None] * len(request_ids) - + if doc_contents_list is None: + doc_contents_list = [None] * len(request_ids) + results = [] - for req_id, docs, parent_id in zip(request_ids, docs_list, parent_request_ids): - result = self.deduplicate(req_id, docs, parent_id, hint_template) + for req_id, docs, parent_id, doc_contents in zip( + request_ids, docs_list, parent_request_ids, doc_contents_list + ): + result = self.deduplicate( + req_id, docs, parent_id, hint_template, doc_contents=doc_contents + ) results.append(result) - + return results - + def remove_request(self, request_id: str) -> bool: """ Remove a request from tracking. - + Note: This will NOT update parent references of child requests. Use with caution. - + Args: request_id: The request to remove - + Returns: True if removed, False if not found """ @@ -304,43 +358,43 @@ def remove_request(self, request_id: str) -> bool: del self._requests[request_id] return True return False - + def clear_conversation(self, request_id: str) -> int: """ Clear all requests in a conversation chain. - + Args: request_id: Any request in the conversation - + Returns: Number of requests removed """ chain = self.get_conversation_chain(request_id) count = 0 - + for history in chain: if self.remove_request(history.request_id): count += 1 - + return count - + def reset(self): """Clear all tracked conversations.""" self._requests.clear() self._stats = { - 'total_requests': 0, - 'total_dedup_calls': 0, - 'total_docs_deduplicated': 0, + "total_requests": 0, + "total_dedup_calls": 0, + "total_docs_deduplicated": 0, } logger.info("ConversationTracker reset") - + def get_stats(self) -> Dict: """Get tracking statistics.""" return { **self._stats, - 'active_requests': len(self._requests), + "active_requests": len(self._requests), } - + def get_request_history(self, request_id: str) -> Optional[RequestHistory]: """Get history for a specific request.""" return self._requests.get(request_id) diff --git a/contextpilot/server/eviction_heap.py b/contextpilot/server/eviction_heap.py deleted file mode 100644 index 33da256..0000000 --- a/contextpilot/server/eviction_heap.py +++ /dev/null @@ -1,316 +0,0 @@ -""" -Eviction Heap for LRU Tracking - -Min-heap that tracks requests by last access time for efficient LRU eviction. -Mirrors SGLang's cache behavior with capacity-based eviction triggering. - -Key Design: -- Tracks by request_id (leaf nodes only have request_id) -- Capacity-based eviction when total tokens exceed max_tokens -- Synchronized with context index tree structure -""" - -import heapq -import time -from typing import Dict, Optional, List, Tuple, Set -from .metadata import NodeMetadata - - -class EvictionHeap: - """ - Min-heap for tracking least recently used requests. - - Uses last_access_time as the priority key. Only tracks leaf nodes - (which have request_id). Supports capacity-based eviction to mirror - SGLang's cache eviction behavior. - - Key invariants: - - Only leaf nodes (with request_id) are tracked - - The heap and context index remain synchronized - - Eviction removes tokens until below max_tokens capacity - """ - - def __init__(self, max_tokens: Optional[int] = None): - """ - Initialize eviction heap. - - Args: - max_tokens: Maximum token capacity (triggers eviction when exceeded) - """ - self._heap: List[Tuple[float, int]] = [] # (access_time, node_id) - self._metadata: Dict[int, NodeMetadata] = {} # node_id -> metadata - self._request_to_node: Dict[str, int] = {} # request_id -> node_id - self._in_heap: Dict[int, bool] = {} # Track which nodes are in heap - self._max_tokens = max_tokens - self._total_tokens = 0 - - @property - def max_tokens(self) -> Optional[int]: - """Get maximum token capacity.""" - return self._max_tokens - - @max_tokens.setter - def max_tokens(self, value: int): - """Set maximum token capacity.""" - self._max_tokens = value - - def push(self, metadata: NodeMetadata): - """ - Add a node to the heap. - - Only leaf nodes (with request_id) should be added. - Tracks extra_tokens (unique to this leaf) not total_tokens, - since shared prefix tokens are only stored once in the cache. - - Args: - metadata: Node metadata to track (must have request_id for leaf nodes) - """ - node_id = metadata.node_id - - if node_id in self._in_heap and self._in_heap[node_id]: - # Node already in heap - update access time and recalculate tokens - old_metadata = self._metadata.get(node_id) - if old_metadata: - # Adjust token count for the difference (use extra_tokens) - self._total_tokens += metadata.extra_tokens - old_metadata.extra_tokens - self._metadata[node_id] = metadata - self.update_access_time(node_id, metadata.last_access_time) - return - - heapq.heappush(self._heap, (metadata.last_access_time, node_id)) - self._metadata[node_id] = metadata - self._in_heap[node_id] = True - # Track extra_tokens (unique to this leaf), not total_tokens - self._total_tokens += metadata.extra_tokens - - # Track request_id -> node_id mapping (only for leaf nodes) - if metadata.request_id: - self._request_to_node[metadata.request_id] = node_id - - def pop(self) -> Optional[NodeMetadata]: - """ - Remove and return the least recently used node. - - Returns: - NodeMetadata of LRU node, or None if heap is empty - """ - while self._heap: - access_time, node_id = heapq.heappop(self._heap) - - # Skip if node was removed or is stale - if node_id not in self._metadata: - continue - - metadata = self._metadata[node_id] - - # Check if this is the current entry (not stale) - if metadata.last_access_time == access_time: - self._in_heap[node_id] = False - # Subtract extra_tokens when popping (unique tokens only) - self._total_tokens -= metadata.extra_tokens - return metadata - - # Stale entry, continue to next - - return None - - def peek(self) -> Optional[NodeMetadata]: - """ - View the least recently used node without removing. - - Returns: - NodeMetadata of LRU node, or None if heap is empty - """ - while self._heap: - access_time, node_id = self._heap[0] - - if node_id not in self._metadata: - heapq.heappop(self._heap) - continue - - metadata = self._metadata[node_id] - - if metadata.last_access_time == access_time: - return metadata - - # Stale entry, remove and continue - heapq.heappop(self._heap) - - return None - - def update_access_time(self, node_id: int, new_time: Optional[float] = None): - """ - Update a node's access time (lazy deletion approach). - - Args: - node_id: Node to update - new_time: New access time (defaults to current time) - """ - if node_id not in self._metadata: - return - - metadata = self._metadata[node_id] - - if new_time is None: - new_time = time.time() - - metadata.last_access_time = new_time - - # Push new entry (old one will be filtered as stale) - heapq.heappush(self._heap, (new_time, node_id)) - - def remove(self, node_id: int): - """ - Remove a node from tracking. - - Uses lazy deletion - actual removal happens during pop/peek. - - Args: - node_id: Node to remove - """ - if node_id in self._metadata: - metadata = self._metadata[node_id] - # Subtract extra_tokens (unique tokens only) - self._total_tokens -= metadata.extra_tokens - - # Remove request_id mapping - if metadata.request_id and metadata.request_id in self._request_to_node: - del self._request_to_node[metadata.request_id] - - del self._metadata[node_id] - - if node_id in self._in_heap: - self._in_heap[node_id] = False - - def get_node_by_request_id(self, request_id: str) -> Optional[NodeMetadata]: - """ - Get node metadata by request_id. - - Args: - request_id: The unique request identifier - - Returns: - NodeMetadata if found, None otherwise - """ - node_id = self._request_to_node.get(request_id) - if node_id is not None: - return self._metadata.get(node_id) - return None - - def update_tokens_for_request(self, request_id: str, input_tokens: int, output_tokens: int) -> bool: - """ - Accumulate tokens for a completed request. - - Called when a request completes and we need to track the total tokens - (input + output) for that request in the eviction heap. - - Args: - request_id: The unique request identifier - input_tokens: Number of input tokens (prompt) - output_tokens: Number of output tokens (generation) - - Returns: - True if successful, False if request_id not found - """ - metadata = self.get_node_by_request_id(request_id) - if metadata is None: - return False - - # Update total tokens - old_tokens = metadata.total_tokens - total_new = input_tokens + output_tokens - delta = total_new - old_tokens - - metadata.total_tokens = total_new - metadata.extra_tokens = max(0, metadata.extra_tokens + delta) - metadata.update_access_time() - - # Update heap tracking - self._total_tokens += delta - heapq.heappush(self._heap, (metadata.last_access_time, metadata.node_id)) - - return True - - def needs_eviction(self) -> bool: - """ - Check if eviction is needed based on capacity. - - Returns: - True if total_tokens exceeds max_tokens - """ - if self._max_tokens is None: - return False - return self._total_tokens > self._max_tokens - - def tokens_to_evict(self) -> int: - """ - Calculate how many tokens need to be evicted. - - Returns: - Number of tokens to evict to get below capacity - """ - if self._max_tokens is None or self._total_tokens <= self._max_tokens: - return 0 - return self._total_tokens - self._max_tokens - - def get_metadata(self, node_id: int) -> Optional[NodeMetadata]: - """Get metadata for a specific node.""" - return self._metadata.get(node_id) - - def is_empty(self) -> bool: - """Check if heap has any active nodes.""" - return self.peek() is None - - def size(self) -> int: - """Get number of active nodes in heap.""" - return len(self._metadata) - - def total_tokens(self) -> int: - """Get total tokens across all tracked nodes.""" - return self._total_tokens - - def get_all_request_ids(self) -> Set[str]: - """Get all tracked request IDs.""" - return set(self._request_to_node.keys()) - - def get_stats(self) -> Dict: - """ - Get statistics about the heap. - - Returns: - Dictionary with heap statistics - """ - if not self._metadata: - return { - 'size': 0, - 'total_tokens': 0, - 'max_tokens': self._max_tokens, - 'utilization_pct': 0, - 'avg_tokens_per_node': 0, - 'oldest_access_time': None, - 'newest_access_time': None, - 'num_requests': 0 - } - - access_times = [m.last_access_time for m in self._metadata.values()] - utilization = (self._total_tokens / self._max_tokens * 100) if self._max_tokens else 0 - - return { - 'size': len(self._metadata), - 'total_tokens': self._total_tokens, - 'max_tokens': self._max_tokens, - 'utilization_pct': utilization, - 'avg_tokens_per_node': self._total_tokens / len(self._metadata), - 'oldest_access_time': min(access_times), - 'newest_access_time': max(access_times), - 'num_requests': len(self._request_to_node) - } - - def __len__(self): - """Get number of active nodes.""" - return len(self._metadata) - - def __repr__(self): - return (f"EvictionHeap(size={len(self._metadata)}, " - f"total_tokens={self._total_tokens}, " - f"max_tokens={self._max_tokens})") diff --git a/contextpilot/server/http_server.py b/contextpilot/server/http_server.py index 6aca488..c71af9a 100644 --- a/contextpilot/server/http_server.py +++ b/contextpilot/server/http_server.py @@ -16,18 +16,21 @@ """ import argparse +import copy +import hashlib +import json import logging import time import asyncio import os import re import uuid -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, cast from contextlib import asynccontextmanager try: from fastapi import FastAPI, HTTPException, Request - from fastapi.responses import JSONResponse + from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, Field import uvicorn import aiohttp @@ -42,7 +45,19 @@ ConversationTracker, DeduplicationResult, get_conversation_tracker, - reset_conversation_tracker + reset_conversation_tracker, +) +from .intercept_parser import ( + parse_intercept_headers, + InterceptConfig, + get_format_handler, +) +from .ttl_eviction import TTLEvictionPolicy, TTLTier, CacheMetrics +from .cloud_adapters import get_cloud_adapter, CloudProviderAdapter +from contextpilot.dedup import ( + dedup_chat_completions, + dedup_responses_api, + DedupResult, ) @@ -65,13 +80,78 @@ _aiohttp_session: Optional[aiohttp.ClientSession] = None _tokenizer = None # AutoTokenizer instance for chat template _model_name: Optional[str] = None # Model name for tokenizer -_stateless_mode: bool = False # Stateless mode: just clustering/scheduling, no cache tracking +_stateless_mode: bool = ( + False # Stateless mode: just clustering/scheduling, no cache tracking +) +# Cloud proxy mode: forward to cloud LLM API with prompt cache optimization +_cloud_mode: bool = False +_chunk_modulus: int = 13 +_cloud_adapter: Optional[CloudProviderAdapter] = None +_cloud_api_key: Optional[str] = None +_ttl_policy: Optional[TTLEvictionPolicy] = None # Persistent string-to-ID mapping for string-input mode. # Same string always gets the same integer ID across /reorder calls. _str_to_id: Dict[str, int] = {} _id_to_str: Dict[int, str] = {} _next_str_id: int = 0 +# Persistent index for the intercept path. First request builds it +# (no reorder); subsequent requests use build_incremental to search +# the existing tree and reorder documents for prefix sharing. +_intercept_index: Optional[ContextPilot] = None + +# ── Conversation-aware intercept state ──────────────────────────────────── +# Tracks which tool results have already been processed, enabling +# skip-old / dedup-new / reorder-new behaviour. Single-conversation +# model (one user at a time). Resets when the system prompt changes. + +from dataclasses import dataclass, field as dc_field + + +@dataclass +class _InterceptConvState: + """Global intercept state for the current conversation.""" + + # Cached copy of the full messages array after modification (reorder/dedup). + # On subsequent turns, old messages are replaced with these cached versions + # so the inference engine's prefix cache sees identical tokens. + cached_messages: list = dc_field(default_factory=list) + # Cached system prompt (Anthropic format only) after modification. + cached_system: Any = None + # Whether the first tool result (reorder candidate) has been processed. + first_tool_result_done: bool = False + # Hashes of individual document strings seen across all tool results. + seen_doc_hashes: set = dc_field(default_factory=set) + # Hashes of single-doc tool_results (file reads, etc.) → tool_call_id. + # Used for cross-turn dedup of individual file reads like SKILL.md. + single_doc_hashes: dict = dc_field(default_factory=dict) + # Whether the system prompt has been processed (reordered) already. + system_processed: bool = False + # Number of messages in the last request. Messages only grow in a + # multi-turn conversation; if the count drops, it's a new session. + last_message_count: int = 0 + + +_intercept_state = _InterceptConvState() + +# TTFT tracking for averages across a session +_ttft_history: List[float] = [] +_ttft_chars_saved_total = 0 + + +def _log_ttft(ttft_ms: float, slimmed: int, chars_saved: int) -> None: + global _ttft_chars_saved_total + _ttft_history.append(ttft_ms) + _ttft_chars_saved_total += chars_saved + avg = sum(_ttft_history) / len(_ttft_history) + logger.info( + f"TTFT: {ttft_ms:.0f}ms " + f"(avg {avg:.0f}ms over {len(_ttft_history)} reqs, " + f"slimmed {slimmed}, saved {chars_saved:,} chars, " + f"total saved {_ttft_chars_saved_total:,} chars)" + ) + + # Request ID normalization (engine -> ContextPilot canonical IDs) _ENGINE_REQ_ID_PREFIX = re.compile(r"^(cmpl-|chatcmpl-|batch-)") _VLLM_REQ_SUFFIX = re.compile(r"^(req-[^-]+)-\d+-[0-9a-f]+$") @@ -89,18 +169,53 @@ def _normalize_request_id(request_id: str) -> str: def _init_config(): """Initialize config from environment variables.""" global _max_tokens, _infer_api_url, _tokenizer, _model_name, _stateless_mode + global _cloud_mode, _cloud_adapter, _cloud_api_key, _ttl_policy # Check stateless mode first env_stateless = os.environ.get("CONTEXTPILOT_STATELESS_MODE", "0") _stateless_mode = env_stateless == "1" + # Check cloud proxy mode + cloud_provider = os.environ.get("CONTEXTPILOT_CLOUD_PROVIDER") + if cloud_provider and _cloud_adapter is None: + _cloud_mode = True + _cloud_adapter = get_cloud_adapter(cloud_provider) + _cloud_api_key = os.environ.get("CONTEXTPILOT_CLOUD_API_KEY", "") + extended = os.environ.get("CONTEXTPILOT_EXTENDED_CACHE") == "1" + if extended: + ext_seconds = _cloud_adapter.get_extended_ttl_seconds() + if ext_seconds is None: + logger.warning( + f"{cloud_provider} does not support --extended-cache, ignoring" + ) + ttl_seconds = _cloud_adapter.get_default_ttl_seconds() + else: + ttl_seconds = ext_seconds + _cloud_adapter.configured_ttl = TTLTier.LONG + else: + ttl_seconds = _cloud_adapter.get_default_ttl_seconds() + _ttl_policy = TTLEvictionPolicy( + default_ttl_seconds=ttl_seconds, + ) + logger.info( + f"Cloud proxy mode: provider={cloud_provider}, " + f"index_ttl={ttl_seconds}s" + + ( + " (extended)" + if extended and _cloud_adapter.supports_extended_cache + else "" + ) + ) + if _max_tokens is None and not _stateless_mode: env_max_tokens = os.environ.get("CONTEXTPILOT_MAX_TOKENS") if env_max_tokens: _max_tokens = int(env_max_tokens) if _infer_api_url is None: - _infer_api_url = os.environ.get("CONTEXTPILOT_INFER_API_URL", "http://localhost:30000") + _infer_api_url = os.environ.get( + "CONTEXTPILOT_INFER_API_URL", "http://localhost:30000" + ) # Initialize tokenizer for chat template if model is specified if _tokenizer is None: @@ -129,17 +244,16 @@ class BuildIndexRequest(BaseModel): linkage_method: str = Field("average", description="Linkage method for clustering") # Multi-turn deduplication fields parent_request_ids: Optional[List[Optional[str]]] = Field( - None, + None, description="List of parent request IDs for multi-turn deduplication. " - "Each element corresponds to a context. None means turn 1 (no parent)." + "Each element corresponds to a context. None means turn 1 (no parent).", ) deduplicate: bool = Field( - False, - description="If True, deduplicate contexts based on conversation history" + False, description="If True, deduplicate contexts based on conversation history" ) hint_template: Optional[str] = Field( None, - description="Template for reference hints. Use {doc_id} and {turn_number} placeholders." + description="Template for reference hints. Use {doc_id} and {turn_number} placeholders.", ) @@ -147,9 +261,9 @@ class ScheduleRequest(BaseModel): """Request to schedule a batch (legacy, use ReorderRequest instead).""" contexts: List[List[Any]] = Field( - ..., + ..., description="List of contexts. Each context is a list of items (int doc IDs OR string doc contents). " - "If strings are provided, identical strings are treated as the same document." + "If strings are provided, identical strings are treated as the same document.", ) alpha: float = Field(0.001, description="Distance computation parameter") use_gpu: bool = Field(False, description="Use GPU for distance computation") @@ -162,7 +276,7 @@ class ReorderRequest(BaseModel): contexts: List[List[Any]] = Field( ..., description="List of contexts. Each context is a list of items (int doc IDs OR string doc contents). " - "If strings are provided, identical strings are treated as the same document." + "If strings are provided, identical strings are treated as the same document.", ) alpha: float = Field(0.001, description="Distance computation parameter") use_gpu: bool = Field(False, description="Use GPU for distance computation") @@ -173,23 +287,23 @@ class ReorderRequest(BaseModel): ) parent_request_ids: Optional[List[Optional[str]]] = Field( None, - description="Parent request IDs for multi-turn deduplication (stateful mode only)" + description="Parent request IDs for multi-turn deduplication (stateful mode only)", ) deduplicate: bool = Field( False, - description="If True, deduplicate contexts based on conversation history (stateful mode only)" + description="If True, deduplicate contexts based on conversation history (stateful mode only)", ) hint_template: Optional[str] = Field( - None, - description="Template for reference hints (stateful mode only)" + None, description="Template for reference hints (stateful mode only)" ) class EvictRequest(BaseModel): """Request to evict (remove) requests from the index.""" - request_ids: List[str] = Field(..., description="List of request IDs to evict/remove") - + request_ids: List[str] = Field( + ..., description="List of request IDs to evict/remove" + ) class SearchRequest(BaseModel): @@ -209,18 +323,18 @@ class InsertRequest(BaseModel): class DeduplicateRequest(BaseModel): """Request to deduplicate contexts for multi-turn conversations.""" - + contexts: List[List[Any]] = Field( ..., description="List of contexts (each is a list of document IDs)" ) parent_request_ids: List[Optional[str]] = Field( - ..., + ..., description="List of parent request IDs. Each element corresponds to a context. " - "None means turn 1 (no parent, will be registered for future dedup)." + "None means turn 1 (no parent, will be registered for future dedup).", ) hint_template: Optional[str] = Field( None, - description="Template for reference hints. Use {doc_id} and {turn_number} placeholders." + description="Template for reference hints. Use {doc_id} and {turn_number} placeholders.", ) @@ -309,6 +423,34 @@ async def health(): } +@app.get("/metrics/ttft") +async def metrics_ttft(last: int = 0): + """Return TTFT history for benchmarking. + + Query params: + last: return only last N entries (0 = all). + """ + history = list(_ttft_history) + if last > 0: + history = history[-last:] + avg = sum(history) / len(history) if history else 0 + return { + "ttft_ms": history, + "count": len(history), + "avg_ms": round(avg, 2), + "total_chars_saved": _ttft_chars_saved_total, + } + + +@app.post("/metrics/ttft/reset") +async def metrics_ttft_reset(): + """Reset TTFT history (call before a benchmark run).""" + global _ttft_chars_saved_total + _ttft_history.clear() + _ttft_chars_saved_total = 0 + return {"status": "ok"} + + @app.post("/reorder") async def reorder(request: ReorderRequest): """ @@ -349,6 +491,7 @@ async def reorder(request: ReorderRequest): # ── internal helpers ───────────────────────────────────────────────────────── + async def _reorder_stateless(request: ReorderRequest): """Stateless reorder: one-shot clustering + scheduling, no state.""" try: @@ -391,8 +534,7 @@ async def _reorder_stateless(request: ReorderRequest): scheduled_contexts = result["reordered_contexts"] if is_string_input: scheduled_contexts = [ - [id_to_str[item_id] for item_id in ctx] - for ctx in scheduled_contexts + [id_to_str[item_id] for item_id in ctx] for ctx in scheduled_contexts ] logger.info( @@ -469,27 +611,40 @@ def _to_output(reordered): dedup_results = None if request.deduplicate: tracker = get_conversation_tracker() + docs_list = result.get("reordered_contexts") or contexts + doc_contents_list = None + if _id_to_str: + doc_contents_list = [ + {did: _id_to_str[did] for did in ctx if did in _id_to_str} + for ctx in docs_list + ] dedup_results = tracker.deduplicate_batch( - request_ids=result['request_ids'], - docs_list=result.get('reordered_contexts') or contexts, + request_ids=result["request_ids"], + docs_list=docs_list, parent_request_ids=request.parent_request_ids, hint_template=request.hint_template, + doc_contents_list=doc_contents_list, ) + if doc_contents_list: + for dc in doc_contents_list: + for did, content in dc.items(): + if did in _id_to_str and content != _id_to_str[did]: + _id_to_str[did] = content logger.info(f"Deduplication: processed {len(dedup_results)} contexts") - reordered = _to_output(result.get('reordered_contexts')) + reordered = _to_output(result.get("reordered_contexts")) response = { "status": "success", "message": "Incremental reorder completed", "mode": "incremental", "input_type": "string" if is_string_input else "integer", "num_contexts": len(contexts), - "matched_count": result['matched_count'], - "merged_count": result['merged_count'], - "request_ids": result['request_ids'], + "matched_count": result["matched_count"], + "merged_count": result["merged_count"], + "request_ids": result["request_ids"], "reordered_contexts": reordered, - "original_indices": result['original_indices'], - "groups": result['groups'], + "original_indices": result["original_indices"], + "groups": result["groups"], "stats": _index.get_stats(), } @@ -498,16 +653,24 @@ def _to_output(reordered): "enabled": True, "results": [ { - "request_id": result['request_ids'][i], + "request_id": result["request_ids"][i], "original_docs": r.original_docs, "deduplicated_docs": r.deduplicated_docs, "overlapping_docs": r.overlapping_docs, "new_docs": r.new_docs, "reference_hints": r.reference_hints, + "blocks_deduped": r.blocks_deduped, + "blocks_total": r.blocks_total, + "block_chars_saved": r.block_chars_saved, } for i, r in enumerate(dedup_results) ], - "total_docs_deduplicated": sum(len(r.overlapping_docs) for r in dedup_results), + "total_docs_deduplicated": sum( + len(r.overlapping_docs) for r in dedup_results + ), + "total_blocks_deduped": sum( + r.blocks_deduped for r in dedup_results + ), } return response @@ -529,20 +692,30 @@ def _to_output(reordered): request_id_mapping = result.get("request_id_mapping", {}) request_ids = result.get("request_ids", []) - logger.info( - f"Index built. Auto-assigned {len(request_id_mapping)} request IDs" - ) + logger.info(f"Index built. Auto-assigned {len(request_id_mapping)} request IDs") dedup_results = None if request.deduplicate: tracker = get_conversation_tracker() - reordered_raw = result.get('reordered_contexts') or contexts + reordered_raw = result.get("reordered_contexts") or contexts + doc_contents_list = None + if _id_to_str: + doc_contents_list = [ + {did: _id_to_str[did] for did in ctx if did in _id_to_str} + for ctx in reordered_raw + ] dedup_results = tracker.deduplicate_batch( request_ids=request_ids, docs_list=reordered_raw, parent_request_ids=request.parent_request_ids, hint_template=request.hint_template, + doc_contents_list=doc_contents_list, ) + if doc_contents_list: + for dc in doc_contents_list: + for did, content in dc.items(): + if did in _id_to_str and content != _id_to_str[did]: + _id_to_str[did] = content logger.info(f"Deduplication: processed {len(dedup_results)} contexts") reordered = _to_output(result.get("reordered_contexts", contexts)) @@ -573,10 +746,16 @@ def _to_output(reordered): "overlapping_docs": r.overlapping_docs, "new_docs": r.new_docs, "reference_hints": r.reference_hints, + "blocks_deduped": r.blocks_deduped, + "blocks_total": r.blocks_total, + "block_chars_saved": r.block_chars_saved, } for i, r in enumerate(dedup_results) ], - "total_docs_deduplicated": sum(len(r.overlapping_docs) for r in dedup_results), + "total_docs_deduplicated": sum( + len(r.overlapping_docs) for r in dedup_results + ), + "total_blocks_deduped": sum(r.blocks_deduped for r in dedup_results), } return response @@ -588,6 +767,7 @@ def _to_output(reordered): # ── Legacy aliases (deprecated, use /reorder) ──────────────────────────────── + @app.post("/build", deprecated=True) async def build_index(request: BuildIndexRequest): """Deprecated — use POST /reorder instead. Kept for backward compatibility.""" @@ -612,10 +792,15 @@ async def schedule_batch(request: ScheduleRequest): alpha=request.alpha, use_gpu=request.use_gpu, linkage_method=request.linkage_method, + initial_tokens_per_context=0, + parent_request_ids=None, + deduplicate=False, + hint_template=None, ) # Force stateless behaviour regardless of server mode return await _reorder_stateless(unified) + @app.post("/evict") async def evict(request: EvictRequest): """ @@ -642,20 +827,16 @@ async def evict(request: EvictRequest): try: logger.debug(f"Eviction incoming IDs: {request.request_ids}") + normalized_ids = [_normalize_request_id(rid) for rid in request.request_ids] normalized_ids = [ - _normalize_request_id(rid) - for rid in request.request_ids - ] - normalized_ids = [ - rid for rid in normalized_ids - if rid and not rid.startswith("HEALTH_CHECK") + rid for rid in normalized_ids if rid and not rid.startswith("HEALTH_CHECK") ] # Deduplicate while preserving order for deterministic logs/responses. normalized_ids = list(dict.fromkeys(normalized_ids)) # Remove the evicted requests from our index - result = _index.remove_requests(normalized_ids) - + result = _index.remove_requests(set(normalized_ids)) + # Also clear conversation history for evicted requests # This ensures ConversationTracker stays in sync with the engine's cache tracker = get_conversation_tracker() @@ -688,40 +869,52 @@ async def evict(request: EvictRequest): async def reset_index(): """ Reset the index to initial state. - + Clears all nodes, metadata, request tracking, conversation history, and string-to-ID mappings. Use this to start fresh without restarting the server. - + After reset, you must call /reorder again before other operations. """ - global _index, _str_to_id, _id_to_str, _next_str_id + global \ + _index, \ + _str_to_id, \ + _id_to_str, \ + _next_str_id, \ + _intercept_index, \ + _intercept_state # Reset conversation tracker reset_conversation_tracker() + # Reset intercept conversation state + _intercept_state = _InterceptConvState() + _intercept_index = None + # Reset string-to-ID mapping _str_to_id = {} _id_to_str = {} _next_str_id = 0 - + if _index is None: return { "status": "success", "message": "No index to reset (was not initialized)", "conversation_tracker": "reset", } - + try: _index.reset() - logger.info("Index, conversation tracker, and string mappings reset successfully") - + logger.info( + "Index, conversation tracker, and string mappings reset successfully" + ) + return { "status": "success", "message": "Index reset to initial state", "conversation_tracker": "reset", } - + except Exception as e: logger.error(f"Error resetting index: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -904,10 +1097,14 @@ async def proxy_completions(request: Request): detail="Inference API URL not configured. Set CONTEXTPILOT_INFER_API_URL env var or use --infer-api-url.", ) + session = _aiohttp_session + if session is None: + raise HTTPException(status_code=503, detail="HTTP session not initialized") + try: # Parse request body body = await request.json() - + # Check for request_id (from manual calls) or rid (from RAGPipeline) # RAGPipeline sends 'rid' directly, manual calls may use 'request_id' request_id = body.pop("request_id", None) or body.get("rid", None) @@ -932,40 +1129,764 @@ async def proxy_completions(request: Request): # Pass request_id to inference engine so it can use the same ID for request tracking # Engine will notify ContextPilot via /evict callback when this request is evicted if request_id: - body["rid"] = request_id # SGLang - body["request_id"] = request_id # vLLM + body["rid"] = request_id # SGLang + body["request_id"] = request_id # vLLM logger.info(f"Proxy: forwarding request with request_id={request_id}") else: - logger.info("Proxy: forwarding request without rid (no ContextPilot tracking)") + logger.info( + "Proxy: forwarding request without rid (no ContextPilot tracking)" + ) # Forward to inference engine api_url = f"{infer_api_url}/v1/completions" logger.debug(f"Proxying to {api_url}") - async with _aiohttp_session.post(api_url, json=body) as response: + async with session.post(api_url, json=body) as response: result = await response.json() # Token tracking is handled by the inference engine via CONTEXTPILOT_INDEX_URL # The engine calls /evict after its internal cache eviction - - # Add request_id to response for client reference + + # Add request_id to response header (not body, to avoid + # breaking strict API response parsers). + cp_headers = {} if request_id and response.status == 200: usage = result.get("usage", {}) - result["_contextpilot"] = { - "request_id": request_id, - "tokens_reported": usage.get("total_tokens", 0), - } + cp_headers["X-ContextPilot-Result"] = json.dumps( + { + "request_id": request_id, + "tokens_reported": usage.get("total_tokens", 0), + } + ) - return JSONResponse(content=result, status_code=response.status) + return JSONResponse( + content=result, status_code=response.status, headers=cp_headers + ) except aiohttp.ClientError as e: logger.error(f"Error proxying to inference engine: {e}") - raise HTTPException(status_code=502, detail=f"Inference engine backend error: {str(e)}") + raise HTTPException( + status_code=502, detail=f"Inference engine backend error: {str(e)}" + ) except Exception as e: logger.error(f"Error in proxy: {e}") raise HTTPException(status_code=500, detail=str(e)) +# ============================================================================ +# HTTP Intercept Proxy Endpoints +# ============================================================================ + + +# ── Conversation-aware helpers ───────────────────────────────────────────── + + +def _hash_text(text: str) -> str: + """Fast 16-hex-char hash for content comparison.""" + return hashlib.sha256(text.encode("utf-8", errors="replace")).hexdigest()[:16] + + +def _get_intercept_state(body: Dict[str, Any]) -> _InterceptConvState: + """Return the global intercept state, resetting if the conversation changed. + + Detection: in a multi-turn agent conversation the messages array only + grows. If the count drops, either a new session started or the host + compacted old messages. Either way, reset all state: the old KV cache + entries are gone (compaction rewrites content), so cached_messages, + seen_doc_hashes, and reorder state are all invalid. + """ + global _intercept_state + msg_count = len(body.get("messages") or []) + if msg_count < _intercept_state.last_message_count: + logger.info( + f"Intercept: message count dropped " + f"({msg_count} < {_intercept_state.last_message_count}), " + f"resetting all state (compaction or new session)" + ) + _intercept_state = _InterceptConvState() + # Skip reorder for the first post-compaction tool result: + # prefix cache is fully invalidated, nothing to align with. + # Go straight to dedup mode so docs are registered for future turns. + _intercept_state.first_tool_result_done = True + _intercept_state.system_processed = True + _intercept_state.last_message_count = msg_count + return _intercept_state + + +def _deduplicate_docs(docs: List[str], state: _InterceptConvState) -> tuple: + """Remove documents already seen in previous tool results. + + Returns (new_docs, deduped_count). Also registers all doc hashes + (including duplicates) in state so future calls can dedup against them. + """ + new_docs = [] + deduped_count = 0 + for doc in docs: + h = _hash_text(doc) + if h in state.seen_doc_hashes: + deduped_count += 1 + else: + new_docs.append(doc) + state.seen_doc_hashes.add(h) + return new_docs, deduped_count + + +# Regex for OpenClaw's EXTERNAL_UNTRUSTED_CONTENT security markers. +# These contain a random hex id that changes every request, preventing +# KV cache prefix sharing for identical content. +_EXTERNAL_MARKER_RE = re.compile( + r'<<<((?:END_)?)EXTERNAL_UNTRUSTED_CONTENT\s+id=\\?"[0-9a-f]+\\?">>>' +) + + +def _strip_external_content_ids(body: Any) -> Any: + """Remove random ids from EXTERNAL_UNTRUSTED_CONTENT markers in the body. + + Walks the body dict/list and applies the regex on every string value, + turning ``<<>>`` into + ``<<>>``. + """ + if isinstance(body, str): + return _EXTERNAL_MARKER_RE.sub( + lambda m: f"<<<{m.group(1) or ''}EXTERNAL_UNTRUSTED_CONTENT>>>", body + ) + if isinstance(body, dict): + return {k: _strip_external_content_ids(v) for k, v in body.items()} + if isinstance(body, list): + return [_strip_external_content_ids(v) for v in body] + return body + + +# API format constants +_OPENAI_CHAT = "openai_chat" +_ANTHROPIC_MESSAGES = "anthropic_messages" + + +def _doc_preview(doc: str, max_len: int = 60) -> str: + """Truncate a document string for log preview.""" + doc = doc.replace("\n", " ").strip() + return doc[:max_len] + "…" if len(doc) > max_len else doc + + +def _reorder_documents(docs: List[str], config: InterceptConfig) -> tuple: + """Reorder a list of document strings via the persistent intercept index. + + All documents form ONE context (a single list of doc strings). + The first call (``_intercept_index is None``) builds the index — + since there is only one context the order stays unchanged. + Subsequent calls use ``build_incremental`` which searches the + existing tree and reorders documents for prefix sharing with + previously cached state. + + Returns (reordered_docs, original_order, reordered_order) where the + order lists are 0-based indices suitable for logging/headers. + """ + global _intercept_index + + contexts = [docs] # All docs = one context + original_order = list(range(len(docs))) + + if _intercept_index is None: + # First call — build index only. 1 context → no reorder possible. + _intercept_index = ContextPilot( + alpha=config.alpha, + use_gpu=False, + linkage_method=config.linkage_method, + ) + _intercept_index.build_and_schedule(contexts=cast(List[List[int]], contexts)) + logger.debug("Intercept index initialised (no reorder on first call)") + return docs, original_order, original_order + + # Subsequent calls — search existing tree and reorder for prefix sharing. + # build_incremental returns reordered_contexts with strings converted. + result = _intercept_index.build_incremental( + contexts=cast(List[List[int]], contexts) + ) + reordered = result.get("reordered_contexts", [docs])[0] + + # Build order mapping: find where each reordered doc was in the original. + doc_to_orig = {} + for i, doc in enumerate(docs): + doc_to_orig.setdefault(doc, []).append(i) + reordered_order = [] + used = set() + for doc in reordered: + for idx in doc_to_orig.get(doc, []): + if idx not in used: + reordered_order.append(idx) + used.add(idx) + break + reordered_docs = [docs[i] for i in reordered_order] + + if logger.isEnabledFor(logging.DEBUG): + for label, order in [("BEFORE", original_order), ("AFTER", reordered_order)]: + previews = [f" [{i}] {_doc_preview(docs[i])}" for i in order] + logger.debug(f"Reorder {label}:\n" + "\n".join(previews)) + + return reordered_docs, original_order, reordered_order + + +async def _intercept_and_forward(request: Request, api_format: str): + """Intercept an LLM API request, reorder documents, and forward. + + 1. Parse X-ContextPilot-* headers → InterceptConfig + 2. Extract documents from system message/prompt and tool_results + 3. Reorder each extraction via ContextPilot clustering + 4. Reconstruct request body with reordered docs + 5. Forward to actual LLM backend, streaming or not + If extraction fails at any step → forward original request unmodified. + """ + if _infer_api_url is None: + _init_config() + + infer_api_url = _infer_api_url or os.environ.get( + "CONTEXTPILOT_INFER_API_URL", "http://localhost:30000" + ) + if not infer_api_url: + raise HTTPException( + status_code=503, + detail="Inference API URL not configured.", + ) + + session = _aiohttp_session + if session is None: + raise HTTPException(status_code=503, detail="HTTP session not initialized") + + try: + body = await request.json() + except Exception: + raise HTTPException(status_code=400, detail="Invalid JSON body") + + # Strip random IDs from OpenClaw's EXTERNAL_UNTRUSTED_CONTENT markers + # early, so extraction/clustering sees deterministic content and + # identical documents share the same KV cache prefix. + body = _strip_external_content_ids(body) + + # Parse intercept config from headers + headers = dict(request.headers) + config = parse_intercept_headers(headers) + total_reordered = 0 + total_deduped = 0 + total_slimmed = 0 + tool_results_skipped = 0 # TODO: never incremented — wire up or remove + _chars_before_slim = 0 + _chars_after_slim = 0 + system_count = 0 + tool_result_count = 0 + reorder_details = [] # collect per-source reorder info + _dedup_result = DedupResult() + state = _intercept_state + + # ── Debug: log conversation shape, divergence, and tool_result details ── + _debug_messages = body.get("messages") or [] + _debug_msg_count = len(_debug_messages) + + # Per-message hashes for this request + _debug_msg_hashes = [] + if logger.isEnabledFor(logging.DEBUG): + for m in _debug_messages: + h = hashlib.sha256( + json.dumps(m, sort_keys=True, ensure_ascii=False).encode() + ).hexdigest()[:12] + _debug_msg_hashes.append(h) + + # Build tool_call_id → function name mapping from assistant messages + _tool_call_names = {} + for m in _debug_messages: + if m.get("role") == "assistant": + for tc in m.get("tool_calls") or []: + _tc_id = tc.get("id", "") + _fn = (tc.get("function") or {}).get("name", "?") + _args_raw = (tc.get("function") or {}).get("arguments", "") + # Extract file path for read calls + _path_hint = "" + if _fn == "read" and isinstance(_args_raw, str): + try: + _args = json.loads(_args_raw) + _p = _args.get("path") or _args.get("file_path") or "" + if _p: + _path_hint = f" path={_p.split('/')[-1]}" + except Exception: + pass + _tool_call_names[_tc_id] = f"{_fn}{_path_hint}" + + # Log all tool_result messages with size, function name, and content preview + for idx, m in enumerate(_debug_messages): + _role = m.get("role", "?") + if _role in ("tool", "toolResult"): + _tc_id = m.get("tool_call_id", "?") + _fn_label = _tool_call_names.get(_tc_id, "?") + _content = m.get("content", "") + _content_str = str(_content) + _chars = len(_content_str) + _is_compacted = "[compacted:" in _content_str + _preview = _content_str[:150].replace("\n", "\\n") + logger.info( + f" msg[{idx}] role={_role} fn={_fn_label} " + f"tool_call_id={_tc_id} " + f"chars={_chars} compacted={_is_compacted} " + f"preview: {_preview}" + ) + elif _role == "user" and isinstance(m.get("content"), list): + for bi, block in enumerate(m["content"]): + if isinstance(block, dict) and block.get("type") in ( + "tool_result", + "toolResult", + ): + _tu_id = block.get("tool_use_id", "?") + _tc = block.get("content", "") + _tc_str = str(_tc) + _chars = len(_tc_str) + _is_compacted = "[compacted:" in _tc_str + _preview = _tc_str[:150].replace("\n", "\\n") + logger.info( + f" msg[{idx}].content[{bi}] type=tool_result " + f"tool_use_id={_tu_id} chars={_chars} " + f"compacted={_is_compacted} preview: {_preview}" + ) + + global _debug_prev_msg_hashes + if "_debug_prev_msg_hashes" not in globals(): + _debug_prev_msg_hashes = [] + + _prev_n = len(_debug_prev_msg_hashes) + if _prev_n > 0 and _prev_n <= _debug_msg_count: + _first_diff = None + for idx in range(_prev_n): + if _debug_msg_hashes[idx] != _debug_prev_msg_hashes[idx]: + _first_diff = idx + break + if _first_diff is not None: + _diff_msg = _debug_messages[_first_diff] + _diff_role = _diff_msg.get("role", "?") + _diff_content = str(_diff_msg.get("content", "")) + logger.warning( + f"Intercept PREFIX MISMATCH at msg[{_first_diff}] " + f"(role={_diff_role}), " + f"hash was {_debug_prev_msg_hashes[_first_diff]} " + f"now {_debug_msg_hashes[_first_diff]}. " + f"Content preview ({len(_diff_content)} chars): " + f"{_diff_content[:300]}..." + ) + else: + logger.info( + f"Intercept: {_debug_msg_count} msgs (prev={_prev_n}), " + f"prefix[:{_prev_n}] MATCH, " + f"{_debug_msg_count - _prev_n} new msgs" + ) + else: + logger.info(f"Intercept: {_debug_msg_count} msgs (first request or reset)") + + _debug_prev_msg_hashes = list(_debug_msg_hashes) + + # ── Format handler (strategy pattern) ──────────────────────────── + handler = get_format_handler(api_format) + + if config.enabled: + try: + # body is already a fresh copy from _strip_external_content_ids + + # ── Conversation-aware state (single-conversation model) ── + state = _get_intercept_state(body) + + # ── Replace old messages with cached (modified) versions ── + # On subsequent turns, the host sends original (unmodified) + # messages. Replace them with our cached modified versions + # so the inference engine's prefix cache sees identical tokens. + old_msg_count = len(state.cached_messages) + if old_msg_count > 0: + msgs = body.get("messages", []) + if len(msgs) >= old_msg_count: + msgs[:old_msg_count] = copy.deepcopy(state.cached_messages) + logger.info( + f"Intercept: replaced {old_msg_count} old messages " + f"with cached versions for prefix cache consistency" + ) + handler.restore_system(body, state.cached_system) + + multi = handler.extract_all(body, config) + + # ── System prompt: reorder only on first turn ───────────── + if multi.system_extraction and not state.system_processed: + extraction, sys_idx = multi.system_extraction + if len(extraction.documents) >= 2: + reordered_docs, orig_order, new_order = _reorder_documents( + extraction.documents, config + ) + if orig_order != new_order: + reorder_details.append( + { + "source": "system", + "count": len(extraction.documents), + "original_order": orig_order, + "reordered_order": new_order, + } + ) + handler.reconstruct_system( + body, extraction, reordered_docs, sys_idx, config + ) + total_reordered += len(extraction.documents) + system_count = 1 + state.system_processed = True + + # ── Tool results: skip cached old, dedup+reorder new ──────── + for extraction, location in multi.tool_extractions: + if location.msg_index < old_msg_count: + continue + if len(extraction.documents) < 2: + continue + + if not state.first_tool_result_done: + # First tool result in session → reorder for KV cache + state.first_tool_result_done = True + reordered_docs, orig_order, new_order = _reorder_documents( + extraction.documents, config + ) + for doc in extraction.documents: + state.seen_doc_hashes.add(_hash_text(doc)) + if orig_order != new_order: + reorder_details.append( + { + "source": f"tool_result[{location.msg_index}]", + "count": len(extraction.documents), + "original_order": orig_order, + "reordered_order": new_order, + } + ) + handler.reconstruct_tool_result( + body, extraction, reordered_docs, location + ) + total_reordered += len(extraction.documents) + tool_result_count += 1 + else: + # Subsequent tool results → dedup only + new_docs, deduped = _deduplicate_docs(extraction.documents, state) + total_deduped += deduped + if deduped > 0: + if not new_docs: + orig_chars = len(extraction.original_content) + new_docs = [ + f"[All {deduped} documents identical to a " + f"previous tool result ({orig_chars} chars). " + f"Refer to the earlier result above.]" + ] + _chars_before_slim += orig_chars + _chars_after_slim += len(new_docs[0]) + total_slimmed += deduped + reorder_details.append( + { + "source": f"tool_result[{location.msg_index}]", + "count": len(new_docs), + "deduped": deduped, + } + ) + handler.reconstruct_tool_result( + body, extraction, new_docs, location + ) + tool_result_count += 1 + + # ── Single-doc tool results: cross-turn dedup ──────────── + for single_doc, location in multi.single_doc_extractions: + if location.msg_index < old_msg_count: + continue + if single_doc.content_hash in state.single_doc_hashes: + prev_tool_id = state.single_doc_hashes[single_doc.content_hash] + if single_doc.tool_call_id == prev_tool_id: + logger.debug( + f"Intercept: skipping old single-doc at " + f"msg[{location.msg_index}] " + f"({len(single_doc.content)} chars, " + f"preserving prefix cache)" + ) + continue + + if handler.tool_call_present(body, prev_tool_id): + hint = ( + f"[Duplicate content — identical to a previous " + f"tool result ({prev_tool_id}). " + f"Refer to the earlier result above.]" + ) + handler.replace_single_doc(body, location, hint) + total_deduped += 1 + logger.debug( + f"Intercept: deduped single-doc at msg[{location.msg_index}] " + f"(hash={single_doc.content_hash[:12]}…, " + f"original={prev_tool_id})" + ) + else: + state.single_doc_hashes[single_doc.content_hash] = ( + single_doc.tool_call_id + ) + logger.debug( + f"Intercept: original single-doc ({prev_tool_id}) " + f"compacted, keeping re-read at msg[{location.msg_index}]" + ) + else: + state.single_doc_hashes[single_doc.content_hash] = ( + single_doc.tool_call_id + ) + + if ( + total_reordered > 0 + or total_deduped > 0 + or total_slimmed > 0 + or tool_results_skipped > 0 + ): + saved = _chars_before_slim - _chars_after_slim + saved_tokens = saved // 4 if saved > 0 else 0 + logger.info( + f"Intercept ({api_format}): reordered {total_reordered}, " + f"deduped {total_deduped}, slimmed {total_slimmed} " + f"(saved {saved:,} chars ≈ {saved_tokens:,} tokens)" + ) + + _dedup_result = DedupResult() + try: + if api_format == _OPENAI_CHAT: + _dedup_result = dedup_chat_completions(body, chunk_modulus=_chunk_modulus) + elif "input" in body and isinstance(body.get("input"), list): + _dedup_result = dedup_responses_api(body, chunk_modulus=_chunk_modulus) + + if _dedup_result.chars_saved > 0: + _chars_before_slim += _dedup_result.chars_before + _chars_after_slim += _dedup_result.chars_after + logger.info( + f"Dedup ({api_format}): " + f"blocks={_dedup_result.blocks_deduped}/{_dedup_result.blocks_total}, " + f"saved {_dedup_result.chars_saved:,} chars" + ) + except Exception as dedup_err: + logger.warning(f"Dedup failed, continuing: {dedup_err}") + + # ── Cache the final messages array for next turn ────────── + state.cached_messages = copy.deepcopy(body.get("messages", [])) + state.cached_system = handler.cache_system(body) + + except Exception as e: + logger.warning( + f"Intercept extraction/reorder failed, forwarding original: {e}" + ) + total_reordered = 0 + total_deduped = 0 + total_slimmed = 0 + + # In stateful mode, inject ContextPilot request_id as `rid` so SGLang + # uses the same ID for cache tracking (enables eviction sync). + if not _cloud_mode and not _stateless_mode and _index is not None: + request_id = f"req-{uuid.uuid4().hex[:12]}" + body["rid"] = request_id + logger.debug(f"Intercept: injected rid={request_id}") + + # ── Cloud proxy mode: inject cache_control + compute content hash ── + _cloud_content_hash = "" + _cloud_request_id = "" + if _cloud_mode and _cloud_adapter is not None and _ttl_policy is not None: + _ttl_policy.evict_expired() + cached_hashes = _ttl_policy.get_cached_hashes() + body = _cloud_adapter.inject_cache_control(body, cached_hashes) + _cloud_content_hash = hashlib.sha256( + json.dumps( + body.get("system", ""), sort_keys=True, ensure_ascii=False + ).encode() + ).hexdigest()[:24] + _cloud_request_id = f"cloud-{uuid.uuid4().hex[:12]}" + + # Determine target URL + if _cloud_mode and _cloud_adapter is not None: + target_url = _cloud_adapter.get_api_url(_cloud_adapter.get_target_path()) + else: + target_url = f"{infer_api_url}{handler.target_path()}" + + # Build outbound headers: forward everything except X-ContextPilot-* + # and hop-by-hop headers that must not be forwarded by proxies. + _HOP_BY_HOP = frozenset( + ( + "host", + "connection", + "keep-alive", + "transfer-encoding", + "te", + "trailer", + "upgrade", + "proxy-authorization", + "proxy-authenticate", + "content-length", + ) + ) + if _cloud_mode and _cloud_adapter is not None and _cloud_api_key: + outbound_headers = _cloud_adapter.get_auth_headers(_cloud_api_key) + else: + outbound_headers = {} + for k, v in headers.items(): + kl = k.lower() + if kl.startswith("x-contextpilot-"): + continue + if kl in _HOP_BY_HOP: + continue + outbound_headers[k] = v + + # Build ContextPilot metadata as a response header (not in body, + # which would break strict API response parsers like OpenClaw's SDK). + cp_response_headers = {} + _has_activity = ( + total_reordered > 0 + or total_deduped > 0 + or total_slimmed > 0 + or tool_results_skipped > 0 + or _dedup_result.chars_saved > 0 + ) + if _has_activity: + cp_response_headers["X-ContextPilot-Result"] = json.dumps( + { + "intercepted": True, + "documents_reordered": total_reordered > 0, + "total_documents": total_reordered, + "documents_deduplicated": total_deduped, + "documents_slimmed": total_slimmed, + "chars_before_slim": _chars_before_slim, + "chars_after_slim": _chars_after_slim, + "chars_saved": _chars_before_slim - _chars_after_slim, + "tool_results_skipped": tool_results_skipped, + "message_count": state.last_message_count, + "sources": { + "system": system_count, + "tool_results": tool_result_count, + }, + "reorder_details": reorder_details, + "dedup": { + "blocks_deduped": _dedup_result.blocks_deduped, + "blocks_total": _dedup_result.blocks_total, + "chars_saved": _dedup_result.chars_saved, + }, + } + ) + + is_stream = body.get("stream", False) + + _request_start = time.monotonic() + + try: + if is_stream: + # Streaming: passthrough SSE chunks, forwarding status & headers + async def _stream_with_headers(): + _ttft_logged = False + async with session.post( + target_url, json=body, headers=outbound_headers + ) as resp: + # Collect response headers to forward + fwd_headers = dict(cp_response_headers) + for k, v in resp.headers.items(): + kl = k.lower() + if kl in _HOP_BY_HOP or kl == "content-length": + continue + fwd_headers[k] = v + # Yield (headers_dict, status) as first item for the wrapper + yield resp.status, fwd_headers + async for chunk in resp.content.iter_any(): + if not _ttft_logged: + _ttft_ms = (time.monotonic() - _request_start) * 1000 + _saved = _chars_before_slim - _chars_after_slim + _log_ttft(_ttft_ms, total_slimmed, _saved) + _ttft_logged = True + yield chunk + + stream_iter = _stream_with_headers() + first_event = await stream_iter.__anext__() + status, fwd_headers = cast(tuple[int, Dict[str, str]], first_event) + + async def _stream_content_only(): + try: + async for event in stream_iter: + if isinstance(event, bytes): + yield event + finally: + await stream_iter.aclose() + + return StreamingResponse( + _stream_content_only(), + status_code=status, + headers=fwd_headers, + media_type=fwd_headers.get("content-type", "text/event-stream"), + ) + else: + # Non-streaming: forward JSON with metadata in header only + async with session.post( + target_url, json=body, headers=outbound_headers + ) as resp: + _ttft_ms = (time.monotonic() - _request_start) * 1000 + _saved = _chars_before_slim - _chars_after_slim + _log_ttft(_ttft_ms, total_slimmed, _saved) + try: + result = await resp.json() + except (json.JSONDecodeError, aiohttp.ContentTypeError): + text = await resp.text() + raise HTTPException(status_code=resp.status, detail=text[:500]) + + # ── Cloud mode: track cache metrics from response ── + if ( + _cloud_mode + and _cloud_adapter is not None + and _ttl_policy is not None + and _cloud_content_hash + ): + metrics = _cloud_adapter.parse_cache_metrics(result) + _ttl_policy.update_from_response( + metrics, _cloud_request_id, content_hash=_cloud_content_hash + ) + if ( + metrics.cache_read_tokens > 0 + and _index is not None + and _cloud_request_id + ): + node_id = _index._request_to_node.get(_cloud_request_id) + if node_id is not None and node_id in _index.metadata: + _index.metadata[node_id].update_access_time() + cp_response_headers["X-ContextPilot-Cloud-Cache"] = json.dumps( + { + "provider": _cloud_adapter.provider_name, + "cache_creation_tokens": metrics.cache_creation_tokens, + "cache_read_tokens": metrics.cache_read_tokens, + "ttl_stats": _ttl_policy.get_stats(), + } + ) + + return JSONResponse( + content=result, + status_code=resp.status, + headers=cp_response_headers, + ) + + except aiohttp.ClientError as e: + logger.error(f"Error forwarding intercepted request: {e}") + raise HTTPException(status_code=502, detail="Backend connection error") + + +@app.post("/v1/chat/completions") +async def intercept_openai_chat(request: Request): + """Intercept OpenAI chat completions: extract docs, reorder, forward.""" + return await _intercept_and_forward(request, _OPENAI_CHAT) + + +@app.post("/v1/messages") +async def intercept_anthropic_messages(request: Request): + """Intercept Anthropic messages: extract docs, reorder, forward.""" + return await _intercept_and_forward(request, _ANTHROPIC_MESSAGES) + + +@app.get("/cloud/stats") +async def cloud_cache_stats(): + """Get cloud prompt cache statistics (cloud proxy mode only).""" + if not _cloud_mode or _ttl_policy is None: + raise HTTPException( + status_code=404, + detail="Cloud proxy mode not enabled. Start with --cloud-provider.", + ) + _ttl_policy.evict_expired() + stats = _ttl_policy.get_stats() + stats["provider"] = _cloud_adapter.provider_name if _cloud_adapter else None + return JSONResponse(content=stats) + + @app.api_route("/v1/{path:path}", methods=["GET", "POST"]) async def proxy_engine(path: str, request: Request): """ @@ -987,34 +1908,132 @@ async def proxy_engine(path: str, request: Request): detail="Inference API URL not configured. Set CONTEXTPILOT_INFER_API_URL env var or use --infer-api-url.", ) + session = _aiohttp_session + if session is None: + raise HTTPException(status_code=503, detail="HTTP session not initialized") + try: - target_url = f"{infer_api_url}/v1/{path}" + if _cloud_mode and _cloud_adapter is not None: + target_url = _cloud_adapter.get_api_url(f"/v1/{path}") + headers = _cloud_adapter.get_auth_headers(_cloud_api_key or "") + else: + target_url = f"{infer_api_url}/v1/{path}" + headers = {} if request.method == "GET": - async with _aiohttp_session.get(target_url) as response: + async with session.get(target_url, headers=headers) as response: result = await response.json() return JSONResponse(content=result, status_code=response.status) else: body = await request.json() - # Inject rid for SGLang cache tracking (same logic as proxy_completions) - request_id = body.pop("request_id", None) or body.get("rid", None) - if not request_id: - request_id = f"req-{uuid.uuid4().hex[:12]}" - logger.debug(f"Auto-assigned request_id={request_id}") - if _index: - _index.track_request(request_id) - if request_id: - body["rid"] = request_id - body["request_id"] = request_id - - async with _aiohttp_session.post(target_url, json=body) as response: - result = await response.json() - return JSONResponse(content=result, status_code=response.status) + if not _cloud_mode: + request_id = body.pop("request_id", None) or body.get("rid", None) + if not request_id: + request_id = f"req-{uuid.uuid4().hex[:12]}" + logger.debug(f"Auto-assigned request_id={request_id}") + if _index: + _index.track_request(request_id) + if request_id: + body["rid"] = request_id + body["request_id"] = request_id + + body.setdefault("temperature", 0) + if _cloud_mode: + body.setdefault("top_p", 0) + + dedup_result = DedupResult() + try: + if path == "responses" or ( + "input" in body and isinstance(body.get("input"), list) + ): + # Log function_call_output stats before dedup + input_items = body.get("input", []) + fco_items = [ + it + for it in input_items + if isinstance(it, dict) + and it.get("type") == "function_call_output" + ] + if fco_items: + import hashlib as _hl + + fco_summary = [] + for it in fco_items: + out = it.get("output", "") + h = _hl.sha256( + out.encode("utf-8", errors="replace") + ).hexdigest()[:12] + call_id = it.get("call_id", "?") + # Find the tool name from function_call items + fn_name = "?" + for fc in input_items: + if ( + isinstance(fc, dict) + and fc.get("type") == "function_call" + and fc.get("call_id") == it.get("call_id") + ): + fn_name = fc.get("name", "?") + break + content_preview = ( + out[:60].replace("\n", "\\n") if len(out) < 100 else "" + ) + fco_summary.append( + f" call={call_id[:20]} fn={fn_name} len={len(out)} hash={h}" + + (f" [{content_preview}]" if content_preview else "") + ) + logger.info( + f"Request /v1/{path}: {len(input_items)} items, " + f"{len(fco_items)} function_call_output:\n" + + "\n".join(fco_summary) + ) + dedup_result = dedup_responses_api(body, chunk_modulus=_chunk_modulus) + elif "messages" in body and isinstance(body.get("messages"), list): + dedup_result = dedup_chat_completions(body, chunk_modulus=_chunk_modulus) + if dedup_result.chars_saved > 0: + logger.info( + f"Passthrough dedup /v1/{path}: " + f"block={dedup_result.blocks_deduped}/{dedup_result.blocks_total} " + f"(saved {dedup_result.chars_saved:,} chars)" + ) + except Exception as pe: + logger.warning(f"Passthrough dedup failed: {pe}") + + response = await session.post(target_url, json=body, headers=headers) + ct = response.headers.get("content-type", "") + if "text/event-stream" in ct: + + async def _sse_passthrough(): + try: + async for chunk in response.content.iter_any(): + yield chunk + finally: + response.close() + + fwd_hdrs = { + k: v + for k, v in response.headers.items() + if k.lower() + not in ( + "transfer-encoding", + "content-encoding", + "content-length", + ) + } + return StreamingResponse( + _sse_passthrough(), + status_code=response.status, + headers=fwd_hdrs, + ) + result = await response.json() + response.close() + return JSONResponse(content=result, status_code=response.status) except aiohttp.ClientError as e: logger.error(f"Error proxying to inference engine: {e}") - raise HTTPException(status_code=502, detail=f"Inference engine backend error: {str(e)}") + raise HTTPException( + status_code=502, detail=f"Inference engine backend error: {str(e)}" + ) except Exception as e: logger.error(f"Error in proxy: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -1033,6 +2052,11 @@ def main(): # Stateless mode (just clustering/scheduling, no index maintained): python -m contextpilot.server.http_server --port 8765 --stateless --infer-api-url http://localhost:30000 + # Cloud proxy mode (forward to cloud LLM API with prompt cache optimization): + python -m contextpilot.server.http_server --port 8765 --cloud-provider anthropic --cloud-api-key sk-ant-xxx + python -m contextpilot.server.http_server --port 8765 --cloud-provider openai --cloud-api-key sk-xxx + python -m contextpilot.server.http_server --port 8765 --cloud-provider minimax --cloud-api-key xxx + Live mode: - Build context index via POST /reorder - Receive eviction callbacks from inference engine at POST /evict @@ -1043,6 +2067,12 @@ def main(): - Use POST /reorder endpoint for one-off batch reordering - No index maintained, no eviction tracking - Each /reorder call is independent + +Cloud proxy mode: + - Forward to cloud LLM APIs (Anthropic, OpenAI, MiniMax) + - Automatically inject cache_control for prompt cache optimization + - Each provider uses its optimal TTL (Anthropic: 5min, OpenAI: 24hr, MiniMax: 5min) + - GET /cloud/stats for cache hit/miss statistics """, ) parser.add_argument("--host", default="0.0.0.0", help="Host to bind to") @@ -1057,7 +2087,7 @@ def main(): "--stateless", action="store_true", help="Run in stateless mode: clustering/reordering only, no index maintained. " - "Use POST /reorder endpoint for batch reordering.", + "Use POST /reorder endpoint for batch reordering.", ) parser.add_argument( "--infer-api-url", @@ -1077,12 +2107,37 @@ def main(): default=None, help="Model name/path for chat template tokenizer (e.g., 'Qwen/Qwen3-32B')", ) + parser.add_argument( + "--cloud-provider", + type=str, + default=None, + choices=["anthropic", "openai", "minimax"], + help="Cloud LLM provider for cloud proxy mode (anthropic/openai/minimax)", + ) + parser.add_argument( + "--cloud-api-key", + type=str, + default=None, + help="API key for cloud provider (or set CONTEXTPILOT_CLOUD_API_KEY env var)", + ) + parser.add_argument( + "--extended-cache", + action="store_true", + default=False, + help="Use extended cache (Anthropic: 1hr, OpenAI: 24hr, MiniMax: N/A)", + ) + parser.add_argument( + "--chunk-modulus", + type=int, + default=13, + help="Content-level dedup block size (avg lines per block). " + "Smaller = more fine-grained dedup but more pointer overhead. " + "Larger = fewer blocks but may miss partial overlaps. " + "Default 13. Range 7-30 recommended.", + ) args = parser.parse_args() - # Note: --max-tokens is no longer required since eviction is now driven by - # engine's callback, not by ContextPilot's internal tracking - # Set environment variables so they propagate to uvicorn workers if args.max_tokens is not None: os.environ["CONTEXTPILOT_MAX_TOKENS"] = str(args.max_tokens) @@ -1090,12 +2145,19 @@ def main(): os.environ["CONTEXTPILOT_STATELESS_MODE"] = "1" if args.stateless else "0" if args.model: os.environ["CONTEXTPILOT_MODEL_NAME"] = args.model + if args.cloud_provider: + os.environ["CONTEXTPILOT_CLOUD_PROVIDER"] = args.cloud_provider + if args.extended_cache: + os.environ["CONTEXTPILOT_EXTENDED_CACHE"] = "1" + if args.cloud_api_key: + os.environ["CONTEXTPILOT_CLOUD_API_KEY"] = args.cloud_api_key # Also set global config for direct access - global _max_tokens, _infer_api_url, _tokenizer, _model_name, _stateless_mode + global _max_tokens, _infer_api_url, _tokenizer, _model_name, _stateless_mode, _chunk_modulus _max_tokens = args.max_tokens _infer_api_url = args.infer_api_url.rstrip("/") _stateless_mode = args.stateless + _chunk_modulus = args.chunk_modulus # Initialize tokenizer for chat template if args.model and AutoTokenizer is not None: @@ -1112,15 +2174,26 @@ def main(): format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", ) - if _stateless_mode: - logger.info(f"Starting ContextPilot Index Server on {args.host}:{args.port} (STATELESS MODE)") + if args.cloud_provider: + logger.info( + f"Starting ContextPilot Index Server on {args.host}:{args.port} (CLOUD PROXY MODE)" + ) + logger.info(f"Cloud provider: {args.cloud_provider}") + logger.info("GET /cloud/stats for cache statistics") + elif _stateless_mode: + logger.info( + f"Starting ContextPilot Index Server on {args.host}:{args.port} (STATELESS MODE)" + ) logger.info("Stateless mode: clustering/scheduling only, no cache tracking") logger.info("Use POST /reorder endpoint for batch reordering") else: - logger.info(f"Starting ContextPilot Index Server on {args.host}:{args.port} (LIVE MODE)") + logger.info( + f"Starting ContextPilot Index Server on {args.host}:{args.port} (LIVE MODE)" + ) logger.info("Use POST /reorder endpoint for stateful reordering") logger.info("Eviction is driven by engine callback (CONTEXTPILOT_INDEX_URL)") - logger.info(f"Inference backend URL: {_infer_api_url}") + if not args.cloud_provider: + logger.info(f"Inference backend URL: {_infer_api_url}") # Run server uvicorn.run( diff --git a/contextpilot/server/intercept_parser.py b/contextpilot/server/intercept_parser.py new file mode 100644 index 0000000..9dfa19e --- /dev/null +++ b/contextpilot/server/intercept_parser.py @@ -0,0 +1,1055 @@ +""" +HTTP Intercept Parser for ContextPilot + +Pure parsing/extraction/reconstruction logic for intercepting LLM API requests. +Extracts documents from system messages, supports reordering, and reconstructs +the request body with reordered documents. + +No server dependencies — independently testable. +""" + +import hashlib +import json +import re +import copy +import logging +from dataclasses import dataclass, field +from typing import List, Optional, Dict, Any, Tuple + +logger = logging.getLogger(__name__) + +# XML wrapper tag names we recognize (outer wrapper) +_KNOWN_WRAPPER_TAGS = {"documents", "contexts", "docs", "passages", "references", "files"} + +# XML item tag names we recognize (inner items) +_KNOWN_ITEM_TAGS = {"document", "context", "doc", "passage", "reference", "file"} + +# Numbered pattern: [1] ... [2] ... etc. +_NUMBERED_RE = re.compile(r"\[(\d+)\]\s*") + +# Separator patterns for auto-detection +_SEPARATOR_PATTERNS = ["---", "==="] + +# Minimum content length to track a single-doc tool_result for dedup. +# Skips tiny results like "ok", error messages, short status outputs. +_SINGLE_DOC_MIN_CHARS = 200 + + +@dataclass +class InterceptConfig: + """Configuration parsed from X-ContextPilot-* headers.""" + + enabled: bool = True + mode: str = "auto" # xml_tag, separator, numbered, markdown_header, auto + tag: str = "document" # XML tag name for xml_tag mode + separator: str = "---" # Delimiter for separator mode + alpha: float = 0.001 + linkage_method: str = "average" + scope: str = "all" # "system", "tool_results", "all" + + +@dataclass +class ExtractionResult: + """Result of extracting documents from a message.""" + + documents: List[str] + prefix: str = "" # Text before the documents block + suffix: str = "" # Text after the documents block + mode: str = "" # Which extraction mode matched + # XML-specific + wrapper_tag: str = "" # e.g. "documents" + item_tag: str = "" # e.g. "document" + # Separator-specific + separator_char: str = "" + # Original content for fallback + original_content: str = "" + # JSON-results-specific: full result objects (documents = content strings) + json_items: Optional[List[Any]] = None + + +@dataclass +class ToolResultLocation: + """Identifies a tool_result's position in the messages array.""" + msg_index: int + block_index: int = -1 # -1 = content is string + inner_block_index: int = -1 # For Anthropic nested content blocks + + +@dataclass +class SingleDocExtraction: + """A tool_result containing a single document (e.g., file read, single API response). + + Not reorderable (only 1 item), but trackable for cross-turn deduplication. + The content_hash enables efficient comparison without storing full content. + """ + content: str + content_hash: str # SHA-256 of stripped content + tool_call_id: str = "" # tool_call_id (OpenAI) or tool_use_id (Anthropic) + + +@dataclass +class MultiExtractionResult: + """Aggregates extractions from system prompt and tool_result messages.""" + system_extraction: Optional[Tuple["ExtractionResult", int]] = None + tool_extractions: List[Tuple["ExtractionResult", ToolResultLocation]] = field(default_factory=list) + single_doc_extractions: List[Tuple["SingleDocExtraction", ToolResultLocation]] = field(default_factory=list) + + @property + def has_extractions(self) -> bool: + return (self.system_extraction is not None + or len(self.tool_extractions) > 0 + or len(self.single_doc_extractions) > 0) + + @property + def total_documents(self) -> int: + total = len(self.single_doc_extractions) + if self.system_extraction: + total += len(self.system_extraction[0].documents) + for ext, _ in self.tool_extractions: + total += len(ext.documents) + return total + + +def _safe_float(value: str, default: float) -> float: + try: + return float(value) + except (ValueError, TypeError): + return default + + +def parse_intercept_headers(headers: Dict[str, str]) -> InterceptConfig: + """Parse X-ContextPilot-* headers into an InterceptConfig.""" + def get(name: str, default: str = "") -> str: + # Headers are case-insensitive; try common casings + key = f"x-contextpilot-{name}" + for k, v in headers.items(): + if k.lower() == key: + return v + return default + + enabled_str = get("enabled", "true").lower() + enabled = enabled_str not in ("false", "0", "no") + + scope = get("scope", "all").lower() + if scope not in ("system", "tool_results", "all"): + scope = "all" + + return InterceptConfig( + enabled=enabled, + mode=get("mode", "auto").lower(), + tag=get("tag", "document").lower(), + separator=get("separator", "---"), + alpha=_safe_float(get("alpha", "0.001"), 0.001), + linkage_method=get("linkage", "average"), + scope=scope, + ) + + +# ── Document extraction ───────────────────────────────────────────────────── + + +def _extract_xml_tags(text: str, config: InterceptConfig) -> Optional[ExtractionResult]: + """Extract documents from XML-tagged blocks. + + Supports patterns like: + ...... + Also handles custom tags and known alternatives (contexts, docs, passages, references). + """ + # Determine which wrapper/item tags to try + if config.mode == "xml_tag": + # User specified xml_tag mode — try their custom tag first + item_tags_to_try = [config.tag] + wrapper_tags_to_try = [config.tag + "s"] # e.g. "document" -> "documents" + # Also add known tags as fallback + item_tags_to_try.extend(t for t in _KNOWN_ITEM_TAGS if t != config.tag) + wrapper_tags_to_try.extend(t for t in _KNOWN_WRAPPER_TAGS if t != config.tag + "s") + else: + item_tags_to_try = list(_KNOWN_ITEM_TAGS) + wrapper_tags_to_try = list(_KNOWN_WRAPPER_TAGS) + + # Try with wrapper tags first (e.g. .........) + for wrapper_tag in wrapper_tags_to_try: + wrapper_pattern = re.compile( + rf"(<{wrapper_tag}(?:\s[^>]*)?>)(.*?)()", + re.DOTALL, + ) + wrapper_match = wrapper_pattern.search(text) + if not wrapper_match: + continue + + inner_text = wrapper_match.group(2) + prefix = text[: wrapper_match.start()] + suffix = text[wrapper_match.end() :] + + # Try each item tag inside the wrapper + for item_tag in item_tags_to_try: + item_pattern = re.compile( + rf"<{item_tag}(?:\s[^>]*)?>(.+?)", + re.DOTALL, + ) + items = item_pattern.findall(inner_text) + if items: + return ExtractionResult( + documents=[item.strip() for item in items], + prefix=prefix, + suffix=suffix, + mode="xml_tag", + wrapper_tag=wrapper_tag, + item_tag=item_tag, + original_content=text, + ) + + # Try without wrapper (just repeated item tags) + for item_tag in item_tags_to_try: + item_pattern = re.compile( + rf"<{item_tag}(?:\s[^>]*)?>(.+?)", + re.DOTALL, + ) + items = list(item_pattern.finditer(text)) + if len(items) >= 2: + first_start = items[0].start() + last_end = items[-1].end() + return ExtractionResult( + documents=[m.group(1).strip() for m in items], + prefix=text[:first_start], + suffix=text[last_end:], + mode="xml_tag", + wrapper_tag="", + item_tag=item_tag, + original_content=text, + ) + + return None + + +def _extract_numbered(text: str, config: InterceptConfig) -> Optional[ExtractionResult]: + """Extract documents from numbered format: [1] doc text [2] doc text ...""" + splits = _NUMBERED_RE.split(text) + # splits will be like: [prefix, "1", doc1, "2", doc2, ...] + # If we found numbered items, splits has at least 4 elements (prefix + one item) + if len(splits) < 4: + return None + + prefix = splits[0] + documents = [] + i = 1 + while i + 1 < len(splits): + # splits[i] is the number, splits[i+1] is the content + doc_text = splits[i + 1].strip() + if doc_text: + documents.append(doc_text) + i += 2 + + if len(documents) < 2: + return None + + return ExtractionResult( + documents=documents, + prefix=prefix, + suffix="", + mode="numbered", + original_content=text, + ) + + +def _extract_separator( + text: str, config: InterceptConfig +) -> Optional[ExtractionResult]: + """Extract documents separated by delimiters (--- or ===).""" + sep = config.separator + # For auto mode, try common separators + if config.mode == "auto": + for candidate in _SEPARATOR_PATTERNS: + # Need the separator to appear on its own line (or at boundary) + parts = re.split(r"\n" + re.escape(candidate) + r"\n", text) + if len(parts) >= 3: + sep = candidate + break + else: + return None + documents = [p.strip() for p in parts if p.strip()] + else: + parts = re.split(r"\n" + re.escape(sep) + r"\n", text) + documents = [p.strip() for p in parts if p.strip()] + + if len(documents) < 2: + return None + + return ExtractionResult( + documents=documents, + prefix="", + suffix="", + mode="separator", + separator_char=sep, + original_content=text, + ) + + +def _extract_markdown_headers( + text: str, config: InterceptConfig +) -> Optional[ExtractionResult]: + """Extract documents by splitting on markdown headers (# or ##). + + Each header + its body becomes one document. Requires >= 2 sections. + Text before the first header is preserved as prefix. + """ + # Split on lines that start with # or ## (but not ### or deeper) + parts = re.split(r"(?=^#{1,2}\s)", text, flags=re.MULTILINE) + # parts[0] is text before first header (prefix), rest are sections + if not parts: + return None + + prefix = "" + sections = [] + for part in parts: + stripped = part.strip() + if not stripped: + continue + if re.match(r"^#{1,2}\s", stripped): + sections.append(stripped) + else: + # Text before first header + prefix = part + + if len(sections) < 2: + return None + + return ExtractionResult( + documents=sections, + prefix=prefix, + suffix="", + mode="markdown_header", + original_content=text, + ) + + +def _reconstruct_markdown_headers( + extraction: ExtractionResult, reordered_docs: List[str] +) -> str: + """Reconstruct markdown-header-split content.""" + parts = [] + if extraction.prefix.strip(): + parts.append(extraction.prefix.rstrip()) + parts.extend(reordered_docs) + return "\n\n".join(parts) + + +# Keys that carry a URL / path identifier suitable for clustering. +_JSON_ID_KEYS = ("url", "path", "file", "filename", "uri", "href") + + +def _extract_json_id(item: dict) -> Optional[str]: + """Extract a URL or path identifier from a JSON result item for clustering.""" + for key in _JSON_ID_KEYS: + val = item.get(key) + if isinstance(val, str) and val.strip(): + return val.strip() + return None + + +def _extract_json_results( + text: str, config: InterceptConfig +) -> Optional[ExtractionResult]: + """Extract documents from JSON tool results with a ``results`` array. + + OpenClaw tools (memory search, web search, web fetch) return + ``JSON.stringify(payload, null, 2)`` where *payload* contains a + ``results`` list. The content field (description/snippet/content/text) + of each item is used for clustering; the full objects are stored in + ``json_items`` so reconstruction can reorder them intact. + """ + stripped = text.strip() + if not stripped.startswith("{"): + return None + try: + obj = json.loads(stripped) + except (json.JSONDecodeError, ValueError): + return None + if not isinstance(obj, dict): + return None + results = obj.get("results") + if not isinstance(results, list) or len(results) < 2: + return None + + # Use url/path as document identifier for clustering — short, stable, + # and same-site results share shingles → cluster together naturally. + # Falls back to full serialised object when no id key is found. + documents = [] + for item in results: + if isinstance(item, dict): + doc_id = _extract_json_id(item) + if doc_id is not None: + documents.append(doc_id) + else: + documents.append(json.dumps(item, ensure_ascii=False)) + else: + documents.append(json.dumps(item, ensure_ascii=False)) + + if len(documents) < 2: + return None + + return ExtractionResult( + documents=documents, + prefix="", + suffix="", + mode="json_results", + original_content=text, + json_items=results, + ) + + +def extract_documents( + text: str, config: InterceptConfig +) -> Optional[ExtractionResult]: + """Extract documents from text using the configured mode. + + Auto-detection priority: xml_tag > numbered > json_results. + ``separator`` and ``markdown_header`` are only used when explicitly + requested — they match structural content (YAML frontmatter, prompt + sections) too aggressively for auto mode. + Returns None if no documents are found (caller should bypass). + """ + if config.mode == "xml_tag": + return _extract_xml_tags(text, config) + elif config.mode == "numbered": + return _extract_numbered(text, config) + elif config.mode == "json_results": + return _extract_json_results(text, config) + elif config.mode == "separator": + return _extract_separator(text, config) + elif config.mode == "markdown_header": + return _extract_markdown_headers(text, config) + else: + # Auto mode: only formats that clearly delimit independent documents. + # separator and markdown_header excluded — too aggressive on + # structural content (YAML frontmatter, prompt sections). + result = _extract_xml_tags(text, config) + if result: + return result + result = _extract_numbered(text, config) + if result: + return result + result = _extract_json_results(text, config) + if result: + return result + return None + + +# ── Reconstruction ─────────────────────────────────────────────────────────── + + +def reconstruct_content( + extraction: ExtractionResult, reordered_docs: List[str] +) -> str: + """Reconstruct the message content with reordered documents. + + Preserves the original format (XML tags, numbering, separators, markdown headers). + """ + if extraction.mode == "xml_tag": + return _reconstruct_xml(extraction, reordered_docs) + elif extraction.mode == "numbered": + return _reconstruct_numbered(extraction, reordered_docs) + elif extraction.mode == "json_results": + return _reconstruct_json_results(extraction, reordered_docs) + elif extraction.mode == "separator": + return _reconstruct_separator(extraction, reordered_docs) + elif extraction.mode == "markdown_header": + return _reconstruct_markdown_headers(extraction, reordered_docs) + else: + # Should not happen, but fallback + return extraction.original_content + + +def _reconstruct_xml( + extraction: ExtractionResult, reordered_docs: List[str] +) -> str: + item_tag = extraction.item_tag + items = "\n".join(f"<{item_tag}>{doc}" for doc in reordered_docs) + + if extraction.wrapper_tag: + wrapper = extraction.wrapper_tag + block = f"<{wrapper}>\n{items}\n" + else: + block = items + + return extraction.prefix + block + extraction.suffix + + +def _reconstruct_numbered( + extraction: ExtractionResult, reordered_docs: List[str] +) -> str: + parts = [extraction.prefix] if extraction.prefix else [] + for i, doc in enumerate(reordered_docs, 1): + parts.append(f"[{i}] {doc}") + result = "\n".join(parts) if parts else "" + if extraction.suffix: + result += extraction.suffix + return result + + +def _reconstruct_json_results( + extraction: ExtractionResult, reordered_docs: List[str] +) -> str: + """Reconstruct a JSON tool result with reordered ``results`` array. + + When ``json_items`` is set (content-key extraction), maps reordered + content strings back to their full result objects via index lookup. + Falls back to parsing reordered_docs as JSON for backward compat. + """ + obj = json.loads(extraction.original_content) + if extraction.json_items is not None: + # Build content → original-index mapping + orig_docs = extraction.documents + doc_to_indices: dict = {} + for i, doc in enumerate(orig_docs): + doc_to_indices.setdefault(doc, []).append(i) + used: set = set() + reordered_items = [] + for doc in reordered_docs: + for idx in doc_to_indices.get(doc, []): + if idx not in used: + reordered_items.append(extraction.json_items[idx]) + used.add(idx) + break + obj["results"] = reordered_items + else: + obj["results"] = [json.loads(doc) for doc in reordered_docs] + return json.dumps(obj, indent=2, ensure_ascii=False) + + +def _reconstruct_separator( + extraction: ExtractionResult, reordered_docs: List[str] +) -> str: + sep = extraction.separator_char or "---" + return ("\n" + sep + "\n").join(reordered_docs) + + +# ── OpenAI Chat format ────────────────────────────────────────────────────── + + +def extract_from_openai_chat( + body: Dict[str, Any], config: InterceptConfig +) -> Optional[Tuple[ExtractionResult, int]]: + """Extract documents from an OpenAI chat completions request body. + + Looks for the system message and extracts documents from its content. + + Returns: + Tuple of (ExtractionResult, system_message_index) or None. + """ + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return None + + for i, msg in enumerate(messages): + if msg.get("role") != "system": + continue + content = msg.get("content", "") + if isinstance(content, str): + result = extract_documents(content, config) + if result: + return result, i + elif isinstance(content, list): + # Content blocks (e.g. [{type: "text", text: "..."}]) + for block in content: + if isinstance(block, dict) and block.get("type") == "text": + result = extract_documents(block.get("text", ""), config) + if result: + return result, i + return None + + +def reconstruct_openai_chat( + body: Dict[str, Any], + extraction: ExtractionResult, + reordered_docs: List[str], + system_msg_index: int, + config: Optional[InterceptConfig] = None, +) -> Dict[str, Any]: + """Reconstruct an OpenAI chat completions request body with reordered docs.""" + body = copy.deepcopy(body) + new_content = reconstruct_content(extraction, reordered_docs) + msg = body["messages"][system_msg_index] + + if isinstance(msg.get("content"), str): + msg["content"] = new_content + elif isinstance(msg.get("content"), list): + cfg = config or InterceptConfig() + for block in msg["content"]: + if isinstance(block, dict) and block.get("type") == "text": + if extract_documents(block.get("text", ""), cfg): + block["text"] = new_content + break + return body + + +# ── Anthropic Messages format ─────────────────────────────────────────────── + + +def extract_from_anthropic_messages( + body: Dict[str, Any], config: InterceptConfig +) -> Optional[ExtractionResult]: + """Extract documents from an Anthropic messages request body. + + Looks at body["system"] which can be a string or list of content blocks. + + Returns: + ExtractionResult or None. + """ + system = body.get("system") + if system is None: + return None + + if isinstance(system, str): + return extract_documents(system, config) + elif isinstance(system, list): + # Content blocks: [{type: "text", text: "..."}] + for block in system: + if isinstance(block, dict) and block.get("type") == "text": + result = extract_documents(block.get("text", ""), config) + if result: + return result + return None + + +def reconstruct_anthropic_messages( + body: Dict[str, Any], + extraction: ExtractionResult, + reordered_docs: List[str], + config: Optional[InterceptConfig] = None, +) -> Dict[str, Any]: + """Reconstruct an Anthropic messages request body with reordered docs.""" + body = copy.deepcopy(body) + new_content = reconstruct_content(extraction, reordered_docs) + + if isinstance(body.get("system"), str): + body["system"] = new_content + elif isinstance(body.get("system"), list): + cfg = config or InterceptConfig() + for block in body["system"]: + if isinstance(block, dict) and block.get("type") == "text": + if extract_documents(block.get("text", ""), cfg): + block["text"] = new_content + break + return body + + +# ── Tool result extraction ───────────────────────────────────────────────── + + +def extract_from_openai_tool_results( + body: Dict[str, Any], config: InterceptConfig +) -> List[Tuple[ExtractionResult, ToolResultLocation]]: + """Extract documents from OpenAI tool result messages (role=="tool"). + + Returns a list of (ExtractionResult, ToolResultLocation) for each tool + result message that contains extractable documents. + """ + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return [] + + results = [] + for i, msg in enumerate(messages): + if msg.get("role") not in ("tool", "toolResult"): + continue + content = msg.get("content", "") + if isinstance(content, str): + extraction = extract_documents(content, config) + if extraction and len(extraction.documents) >= 2: + loc = ToolResultLocation(msg_index=i) + results.append((extraction, loc)) + elif isinstance(content, list): + for j, block in enumerate(content): + if isinstance(block, dict) and block.get("type") == "text": + extraction = extract_documents(block.get("text", ""), config) + if extraction and len(extraction.documents) >= 2: + loc = ToolResultLocation(msg_index=i, block_index=j) + results.append((extraction, loc)) + return results + + +def extract_from_anthropic_tool_results( + body: Dict[str, Any], config: InterceptConfig +) -> List[Tuple[ExtractionResult, ToolResultLocation]]: + """Extract documents from Anthropic tool_result content blocks. + + Anthropic tool results appear as messages with role=="user" containing + content blocks of type=="tool_result". + """ + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return [] + + results = [] + for i, msg in enumerate(messages): + if msg.get("role") != "user": + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for j, block in enumerate(content): + if not isinstance(block, dict) or block.get("type") not in ("tool_result", "toolResult"): + continue + tr_content = block.get("content", "") + if isinstance(tr_content, str): + extraction = extract_documents(tr_content, config) + if extraction and len(extraction.documents) >= 2: + loc = ToolResultLocation(msg_index=i, block_index=j) + results.append((extraction, loc)) + elif isinstance(tr_content, list): + for k, inner in enumerate(tr_content): + if isinstance(inner, dict) and inner.get("type") == "text": + extraction = extract_documents(inner.get("text", ""), config) + if extraction and len(extraction.documents) >= 2: + loc = ToolResultLocation(msg_index=i, block_index=j, inner_block_index=k) + results.append((extraction, loc)) + return results + + +# ── Tool result reconstruction ───────────────────────────────────────────── + + +def reconstruct_openai_tool_result( + body: Dict[str, Any], + extraction: ExtractionResult, + reordered_docs: List[str], + location: ToolResultLocation, +) -> None: + """Reconstruct an OpenAI tool result message in-place.""" + new_content = reconstruct_content(extraction, reordered_docs) + msg = body["messages"][location.msg_index] + if location.block_index == -1: + msg["content"] = new_content + else: + msg["content"][location.block_index]["text"] = new_content + + +def reconstruct_anthropic_tool_result( + body: Dict[str, Any], + extraction: ExtractionResult, + reordered_docs: List[str], + location: ToolResultLocation, +) -> None: + """Reconstruct an Anthropic tool_result content block in-place.""" + new_content = reconstruct_content(extraction, reordered_docs) + msg = body["messages"][location.msg_index] + block = msg["content"][location.block_index] + if location.inner_block_index == -1: + block["content"] = new_content + else: + block["content"][location.inner_block_index]["text"] = new_content + + +# ── Aggregate extraction ─────────────────────────────────────────────────── + + +def extract_all_openai( + body: Dict[str, Any], config: InterceptConfig +) -> MultiExtractionResult: + """Extract documents from both system message and tool results (OpenAI format).""" + result = MultiExtractionResult() + if config.scope in ("system", "all"): + sys_result = extract_from_openai_chat(body, config) + if sys_result: + result.system_extraction = sys_result + if config.scope in ("tool_results", "all"): + result.tool_extractions = extract_from_openai_tool_results(body, config) + result.single_doc_extractions = extract_single_docs_from_openai_tool_results( + body, config) + return result + + +def extract_all_anthropic( + body: Dict[str, Any], config: InterceptConfig +) -> MultiExtractionResult: + """Extract documents from both system prompt and tool results (Anthropic format).""" + result = MultiExtractionResult() + if config.scope in ("system", "all"): + sys_extraction = extract_from_anthropic_messages(body, config) + if sys_extraction and len(sys_extraction.documents) >= 2: + # Use -1 as sentinel for "system field" (not in messages array) + result.system_extraction = (sys_extraction, -1) + if config.scope in ("tool_results", "all"): + result.tool_extractions = extract_from_anthropic_tool_results(body, config) + result.single_doc_extractions = extract_single_docs_from_anthropic_tool_results( + body, config) + return result + + +# ── Single-document extraction (for cross-turn dedup) ───────────────────── + + +def _make_single_doc(content: str, tool_call_id: str = "") -> SingleDocExtraction: + """Create a SingleDocExtraction with content hash.""" + stripped = content.strip() + content_hash = hashlib.sha256(stripped.encode()).hexdigest() + return SingleDocExtraction( + content=stripped, + content_hash=content_hash, + tool_call_id=tool_call_id, + ) + + +def extract_single_docs_from_openai_tool_results( + body: Dict[str, Any], config: InterceptConfig +) -> List[Tuple[SingleDocExtraction, ToolResultLocation]]: + """Extract single-document tool results for dedup tracking (OpenAI format). + + Only captures tool_results where multi-doc extraction failed — these are + individual file reads, single API responses, etc. that contain substantial + content worth deduplicating across turns. + """ + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return [] + + results = [] + for i, msg in enumerate(messages): + if msg.get("role") not in ("tool", "toolResult"): + continue + tool_call_id = msg.get("tool_call_id", "") + content = msg.get("content", "") + + if isinstance(content, str): + # Skip if multi-doc extraction would succeed + extraction = extract_documents(content, config) + if extraction and len(extraction.documents) >= 2: + continue + if len(content.strip()) >= _SINGLE_DOC_MIN_CHARS: + loc = ToolResultLocation(msg_index=i) + results.append((_make_single_doc(content, tool_call_id), loc)) + elif isinstance(content, list): + for j, block in enumerate(content): + if not isinstance(block, dict) or block.get("type") != "text": + continue + text = block.get("text", "") + extraction = extract_documents(text, config) + if extraction and len(extraction.documents) >= 2: + continue + if len(text.strip()) >= _SINGLE_DOC_MIN_CHARS: + loc = ToolResultLocation(msg_index=i, block_index=j) + results.append((_make_single_doc(text, tool_call_id), loc)) + return results + + +def extract_single_docs_from_anthropic_tool_results( + body: Dict[str, Any], config: InterceptConfig +) -> List[Tuple[SingleDocExtraction, ToolResultLocation]]: + """Extract single-document tool results for dedup tracking (Anthropic format). + + Anthropic tool results appear as user messages with type=="tool_result" blocks. + """ + messages = body.get("messages") + if not messages or not isinstance(messages, list): + return [] + + results = [] + for i, msg in enumerate(messages): + if msg.get("role") != "user": + continue + content = msg.get("content") + if not isinstance(content, list): + continue + for j, block in enumerate(content): + if not isinstance(block, dict): + continue + if block.get("type") not in ("tool_result", "toolResult"): + continue + tool_use_id = block.get("tool_use_id", "") + tr_content = block.get("content", "") + + if isinstance(tr_content, str): + extraction = extract_documents(tr_content, config) + if extraction and len(extraction.documents) >= 2: + continue + if len(tr_content.strip()) >= _SINGLE_DOC_MIN_CHARS: + loc = ToolResultLocation(msg_index=i, block_index=j) + results.append((_make_single_doc(tr_content, tool_use_id), loc)) + elif isinstance(tr_content, list): + for k, inner in enumerate(tr_content): + if not isinstance(inner, dict) or inner.get("type") != "text": + continue + text = inner.get("text", "") + extraction = extract_documents(text, config) + if extraction and len(extraction.documents) >= 2: + continue + if len(text.strip()) >= _SINGLE_DOC_MIN_CHARS: + loc = ToolResultLocation(msg_index=i, block_index=j, inner_block_index=k) + results.append((_make_single_doc(text, tool_use_id), loc)) + return results + + +# ── Single-document hint replacement ────────────────────────────────────── + + +def replace_single_doc_openai( + body: Dict[str, Any], location: ToolResultLocation, hint: str, +) -> None: + """Replace a single-doc OpenAI tool result content with a dedup hint in-place.""" + msg = body["messages"][location.msg_index] + if location.block_index == -1: + msg["content"] = hint + else: + msg["content"][location.block_index]["text"] = hint + + +def replace_single_doc_anthropic( + body: Dict[str, Any], location: ToolResultLocation, hint: str, +) -> None: + """Replace a single-doc Anthropic tool_result content with a dedup hint in-place.""" + msg = body["messages"][location.msg_index] + block = msg["content"][location.block_index] + if location.inner_block_index == -1: + block["content"] = hint + else: + block["content"][location.inner_block_index]["text"] = hint + + +# ── Format handler abstraction ───────────────────────────────────────────── +# Strategy pattern: encapsulates all format-specific operations so the +# intercept logic in http_server.py can be completely format-agnostic. + + +from abc import ABC, abstractmethod + + +class FormatHandler(ABC): + """Abstract handler for API format-specific operations. + + Implement one per LLM API format (OpenAI Chat, Anthropic Messages, etc.). + The intercept pipeline calls these methods instead of branching on format. + """ + + @abstractmethod + def extract_all(self, body: Dict[str, Any], + config: InterceptConfig) -> MultiExtractionResult: ... + + @abstractmethod + def reconstruct_system(self, body: Dict[str, Any], + extraction: ExtractionResult, + docs: List[str], sys_idx: int, + config: Optional["InterceptConfig"] = None) -> None: ... + + @abstractmethod + def reconstruct_tool_result(self, body: Dict[str, Any], + extraction: ExtractionResult, + docs: List[str], + location: ToolResultLocation) -> None: ... + + @abstractmethod + def replace_single_doc(self, body: Dict[str, Any], + location: ToolResultLocation, + hint: str) -> None: ... + + @abstractmethod + def tool_call_present(self, body: Dict[str, Any], + tool_call_id: str) -> bool: ... + + @abstractmethod + def target_path(self) -> str: ... + + @abstractmethod + def cache_system(self, body: Dict[str, Any]) -> Any: ... + + @abstractmethod + def restore_system(self, body: Dict[str, Any], cached: Any) -> None: ... + + +class OpenAIChatHandler(FormatHandler): + """Handler for OpenAI Chat Completions format.""" + + def extract_all(self, body, config): + return extract_all_openai(body, config) + + def reconstruct_system(self, body, extraction, docs, sys_idx, config=None): + new_content = reconstruct_content(extraction, docs) + msg = body["messages"][sys_idx] + if isinstance(msg.get("content"), str): + msg["content"] = new_content + elif isinstance(msg.get("content"), list): + cfg = config or InterceptConfig() + for block in msg["content"]: + if isinstance(block, dict) and block.get("type") == "text": + if extract_documents(block.get("text", ""), cfg): + block["text"] = new_content + break + + def reconstruct_tool_result(self, body, extraction, docs, location): + reconstruct_openai_tool_result(body, extraction, docs, location) + + def replace_single_doc(self, body, location, hint): + replace_single_doc_openai(body, location, hint) + + def tool_call_present(self, body, tool_call_id): + for msg in (body.get("messages") or []): + if msg.get("role") in ("tool", "toolResult"): + if msg.get("tool_call_id") == tool_call_id: + return True + return False + + def target_path(self): + return "/v1/chat/completions" + + def cache_system(self, body): + return None # System prompt is inside messages array + + def restore_system(self, body, cached): + pass # No-op: cached with messages + + +class AnthropicMessagesHandler(FormatHandler): + """Handler for Anthropic Messages API format.""" + + def extract_all(self, body, config): + return extract_all_anthropic(body, config) + + def reconstruct_system(self, body, extraction, docs, sys_idx, config=None): + new_content = reconstruct_content(extraction, docs) + if isinstance(body.get("system"), str): + body["system"] = new_content + elif isinstance(body.get("system"), list): + cfg = config or InterceptConfig() + for block in body["system"]: + if isinstance(block, dict) and block.get("type") == "text": + if extract_documents(block.get("text", ""), cfg): + block["text"] = new_content + break + + def reconstruct_tool_result(self, body, extraction, docs, location): + reconstruct_anthropic_tool_result(body, extraction, docs, location) + + def replace_single_doc(self, body, location, hint): + replace_single_doc_anthropic(body, location, hint) + + def tool_call_present(self, body, tool_call_id): + for msg in (body.get("messages") or []): + if msg.get("role") == "user" and isinstance(msg.get("content"), list): + for block in msg["content"]: + if (isinstance(block, dict) + and block.get("type") in ("tool_result", "toolResult") + and block.get("tool_use_id") == tool_call_id): + return True + return False + + def target_path(self): + return "/v1/messages" + + def cache_system(self, body): + return copy.deepcopy(body.get("system")) + + def restore_system(self, body, cached): + if cached is not None: + body["system"] = copy.deepcopy(cached) + + +# Handler registry — add new formats here. +_FORMAT_HANDLERS: Dict[str, FormatHandler] = { + "openai_chat": OpenAIChatHandler(), + "anthropic_messages": AnthropicMessagesHandler(), +} + + +def get_format_handler(api_format: str) -> FormatHandler: + """Return the handler for the given API format, defaulting to OpenAI.""" + return _FORMAT_HANDLERS.get(api_format, _FORMAT_HANDLERS["openai_chat"]) diff --git a/contextpilot/server/live_index.py b/contextpilot/server/live_index.py index cbaeb31..2952942 100644 --- a/contextpilot/server/live_index.py +++ b/contextpilot/server/live_index.py @@ -898,26 +898,21 @@ def schedule_only(self, contexts: List[List[int]]) -> Dict: Returns: Dictionary with scheduled results (no request_id mapping) """ - print("=" * 80) - print("SCHEDULING BATCH (STATELESS MODE)") - print("=" * 80) - + logger.debug("Scheduling batch") + # Step 1: Build static index (clustering + reordering) - print("\n1. Building static index...") result = self.fit_transform(contexts) - - print(f" ✓ Built tree with {result.stats['total_nodes']} nodes") - print(f" ✓ Leaf nodes: {result.stats['leaf_nodes']}") - + logger.debug( + f"Built tree: {result.stats['total_nodes']} nodes, " + f"{result.stats['leaf_nodes']} leaves" + ) + # Step 2: Inter-context scheduling - print("\n2. Scheduling contexts for optimal execution...") scheduled_reordered, scheduled_originals, final_mapping, groups = \ self.inter_scheduler.schedule_contexts(result) - - print(f" ✓ Created {len(groups)} execution groups") - - # Return results without going live (stateless) - scheduled_result = { + logger.debug(f"Created {len(groups)} execution groups") + + return { 'reordered_contexts': scheduled_reordered, 'original_indices': final_mapping, 'scheduled_originals': scheduled_originals, @@ -929,12 +924,6 @@ def schedule_only(self, contexts: List[List[int]]) -> Dict: 'num_groups': len(groups), } } - - print("\n" + "=" * 80) - print("✓ BATCH SCHEDULED (Stateless - no cache tracking)") - print("=" * 80 + "\n") - - return scheduled_result def _initialize_live_metadata(self, initial_tokens_per_context: int, num_input_contexts: int = None) -> Tuple[Dict[str, int], List[str]]: """ diff --git a/contextpilot/server/ttl_eviction.py b/contextpilot/server/ttl_eviction.py new file mode 100644 index 0000000..08d2d99 --- /dev/null +++ b/contextpilot/server/ttl_eviction.py @@ -0,0 +1,324 @@ +""" +TTL-based Eviction Policy for Cloud Prompt Cache Proxy + +Models the cache state of cloud LLM providers (Anthropic, OpenAI, MiniMax) +locally. Cloud providers evict cached prompts after a TTL window: +- Anthropic ephemeral: ~5 minutes +- OpenAI automatic: ~5-10 minutes +- MiniMax: configurable (5 min default) + +This module tracks what content is currently cached in the cloud provider +so ContextPilot can optimize document ordering to maximize cache hits. + +Key Design: +- TTL-based expiry +- Thread-safe for concurrent request handling +- Tracks hit/miss statistics for monitoring +- Supports two TTL tiers: SHORT (5 min) and LONG (1 hr) +""" + +import logging +import time +import threading +from dataclasses import dataclass, field +from enum import Enum +from typing import Dict, List, Optional, Set + +logger = logging.getLogger(__name__) + + +class TTLTier(str, Enum): + SHORT = "5m" # 300 seconds + MEDIUM = "1h" # 3600 seconds + LONG = "24h" # 86400 seconds + + @property + def seconds(self) -> int: + return {"5m": 300, "1h": 3600, "24h": 86400}[self.value] + + +@dataclass +class CacheEntry: + """A single cached content entry with TTL tracking.""" + + content_hash: str + request_id: str + created_at: float + last_accessed_at: float + ttl_seconds: int + token_count: int = 0 + + def is_expired(self, now: Optional[float] = None) -> bool: + """Check if this entry has expired.""" + if now is None: + now = time.time() + return (now - self.last_accessed_at) >= self.ttl_seconds + + def time_remaining(self, now: Optional[float] = None) -> float: + """Get seconds remaining before expiry. Negative if expired.""" + if now is None: + now = time.time() + return self.ttl_seconds - (now - self.last_accessed_at) + + def __repr__(self): + remaining = self.time_remaining() + status = f"{remaining:.0f}s left" if remaining > 0 else "EXPIRED" + return ( + f"CacheEntry(hash={self.content_hash[:12]}..., " + f"tokens={self.token_count}, {status})" + ) + + +@dataclass +class CacheMetrics: + """Cache usage metrics from a cloud API response.""" + + cache_creation_tokens: int = 0 + cache_read_tokens: int = 0 + input_tokens: int = 0 + output_tokens: int = 0 + + +class TTLEvictionPolicy: + """ + TTL-based cache eviction policy for cloud prompt cache proxy. + + Models the cloud provider's cache state locally. Entries expire + after their TTL window (measured from last access time), mirroring + how cloud providers evict cached prompts. + + Thread-safe for concurrent request handling. + + Usage: + policy = TTLEvictionPolicy(default_ttl=TTLTier.SHORT) + + # Track cached content + policy.add_entry("abc123", token_count=5000) + + # Check if content is still cached + if policy.is_cached("abc123"): + print("Cache hit!") + + # Update from API response + policy.update_from_response(metrics, "abc123") + + # Periodic cleanup + evicted = policy.evict_expired() + """ + + def __init__( + self, + default_ttl: TTLTier = TTLTier.SHORT, + default_ttl_seconds: Optional[int] = None, + ): + self._default_ttl = default_ttl + self._default_ttl_seconds = default_ttl_seconds if default_ttl_seconds is not None else default_ttl.seconds + self._entries: Dict[str, CacheEntry] = {} + self._lock = threading.Lock() + + self._total_hits = 0 + self._total_misses = 0 + self._total_evictions = 0 + self._total_additions = 0 + + @property + def default_ttl(self) -> TTLTier: + """Get default TTL tier.""" + return self._default_ttl + + @default_ttl.setter + def default_ttl(self, value: TTLTier): + """Set default TTL tier.""" + self._default_ttl = value + self._default_ttl_seconds = value.seconds + + def add_entry( + self, request_id: str, content_hash: str = "", token_count: int = 0 + ) -> Optional[CacheEntry]: + if not request_id: + return None + ttl_secs = self._default_ttl_seconds + now = time.time() + + with self._lock: + if request_id in self._entries: + entry = self._entries[request_id] + entry.last_accessed_at = now + entry.token_count = token_count or entry.token_count + logger.debug(f"Cache entry refreshed: {request_id}") + else: + entry = CacheEntry( + content_hash=content_hash, + request_id=request_id, + created_at=now, + last_accessed_at=now, + ttl_seconds=ttl_secs, + token_count=token_count, + ) + self._entries[request_id] = entry + self._total_additions += 1 + logger.debug( + f"Cache entry added: {request_id} " + f"(ttl={ttl_secs}s, tokens={token_count})" + ) + + return entry + + def touch_entry(self, request_id: str) -> bool: + with self._lock: + entry = self._entries.get(request_id) + if entry is None: + return False + if entry.is_expired(): + del self._entries[request_id] + self._total_evictions += 1 + return False + entry.last_accessed_at = time.time() + self._total_hits += 1 + return True + + def is_cached(self, request_id: str) -> bool: + with self._lock: + entry = self._entries.get(request_id) + if entry is None: + self._total_misses += 1 + return False + if entry.is_expired(): + del self._entries[request_id] + self._total_evictions += 1 + self._total_misses += 1 + return False + self._total_hits += 1 + return True + + def evict_expired(self) -> List[CacheEntry]: + """ + Remove all entries that have exceeded their TTL. + + Returns: + List of evicted CacheEntry objects + """ + now = time.time() + evicted = [] + + with self._lock: + expired_hashes = [ + h for h, entry in self._entries.items() if entry.is_expired(now) + ] + for h in expired_hashes: + evicted.append(self._entries.pop(h)) + self._total_evictions += len(evicted) + + if evicted: + total_tokens = sum(e.token_count for e in evicted) + logger.info( + f"TTL eviction: removed {len(evicted)} expired entries " + f"({total_tokens:,} tokens)" + ) + + return evicted + + def get_cached_hashes(self) -> Set[str]: + now = time.time() + with self._lock: + return { + entry.content_hash + for entry in self._entries.values() + if not entry.is_expired(now) and entry.content_hash + } + + def get_cached_count(self) -> int: + """Get number of active (non-expired) cache entries.""" + now = time.time() + with self._lock: + return sum( + 1 for entry in self._entries.values() if not entry.is_expired(now) + ) + + def get_total_cached_tokens(self) -> int: + """Get total tokens across all active cache entries.""" + now = time.time() + with self._lock: + return sum( + entry.token_count + for entry in self._entries.values() + if not entry.is_expired(now) + ) + + def update_from_response( + self, metrics: CacheMetrics, request_id: str, content_hash: str = "" + ) -> None: + if not request_id: + return + + if metrics.cache_creation_tokens > 0: + # New or partial cache write — create/update entry with new token count + self.add_entry( + request_id, + content_hash=content_hash, + token_count=metrics.cache_creation_tokens, + ) + logger.debug( + f"Cache write confirmed: {request_id} " + f"({metrics.cache_creation_tokens} tokens cached)" + ) + elif metrics.cache_read_tokens > 0: + # Pure cache hit — just refresh the entry + self.touch_entry(request_id) + logger.debug( + f"Cache hit confirmed: {request_id} " + f"({metrics.cache_read_tokens} tokens read from cache)" + ) + + def reset(self) -> None: + """Clear all entries and reset statistics.""" + with self._lock: + self._entries.clear() + self._total_hits = 0 + self._total_misses = 0 + self._total_evictions = 0 + self._total_additions = 0 + logger.info("TTL eviction policy reset") + + def get_stats(self) -> Dict: + """ + Get cache statistics. + + Returns: + Dictionary with hit/miss counts, active entries, total tokens, etc. + """ + now = time.time() + with self._lock: + active_entries = [ + e for e in self._entries.values() if not e.is_expired(now) + ] + total_tokens = sum(e.token_count for e in active_entries) + total_requests = self._total_hits + self._total_misses + hit_rate = ( + self._total_hits / total_requests * 100 if total_requests > 0 else 0 + ) + + return { + "active_entries": len(active_entries), + "total_entries": len(self._entries), + "total_cached_tokens": total_tokens, + "total_hits": self._total_hits, + "total_misses": self._total_misses, + "total_evictions": self._total_evictions, + "total_additions": self._total_additions, + "hit_rate_pct": round(hit_rate, 2), + "default_ttl": self._default_ttl.value, + "default_ttl_seconds": self._default_ttl_seconds, + } + + def __len__(self): + """Get number of entries (including possibly expired).""" + return len(self._entries) + + def __repr__(self): + active = self.get_cached_count() + return ( + f"TTLEvictionPolicy(active={active}, " + f"total={len(self._entries)}, " + f"default_ttl={self._default_ttl.value})" + ) diff --git a/docs/README.md b/docs/README.md index 6fae7d0..5911c6c 100644 --- a/docs/README.md +++ b/docs/README.md @@ -1,34 +1,42 @@ # ContextPilot Documentation -Welcome to the ContextPilot documentation. This guide covers everything you need to get started and make the most of ContextPilot. - ## Getting Started | Guide | Description | |-------|-------------| | [Installation](getting_started/installation.md) | System requirements and pip install | | [Quick Start](getting_started/quickstart.md) | Your first ContextPilot pipeline in 5 minutes | +| [Docker](getting_started/docker.md) | Container deployment | -## User Guides +## Guides | Guide | Description | |-------|-------------| +| [OpenClaw Integration](guides/openclaw.md) | Proxy setup for OpenClaw agents | +| [How It Works](guides/how_it_works.md) | Reorder and deduplication explained | +| [Cache Synchronization](guides/cache_sync.md) | Self-hosted (eviction callbacks) vs cloud (TTL) | | [Offline Usage](guides/offline_usage.md) | Batch processing without server | -| [Online Usage](guides/online_usage.md) | Index server (stateless & stateful modes) | -| [Engine Integration](guides/online_usage.md#inference-engine-integration) | **Required for stateful mode** — zero-patch eviction callbacks for SGLang, vLLM, and llama.cpp | -| [Multi-Turn Conversations](guides/multi_turn.md) | Context deduplication across turns (30-60% savings) | -| [PageIndex Integration](guides/pageindex.md) | Tree-structured documents → ContextPilot scheduling | -| [mem0 Integration](guides/mem0.md) | LoCoMo benchmark with mem0 memory backend | -| [Docker](guides/docker.md) | All-in-one and standalone container deployment | +| [Online Usage](guides/online_usage.md) | Index server (stateless and stateful modes) | +| [Multi-Turn Conversations](guides/multi_turn.md) | Context deduplication across turns | +| [Mem0 Integration](guides/mem0.md) | Memory-augmented chat with LoCoMo benchmark | +| [PageIndex Integration](guides/pageindex.md) | Tree-structured documents | +| [Mac + llama.cpp](guides/mac_llama_cpp.md) | Apple Silicon deployment | + +## Benchmarks + +| Benchmark | Description | +|-----------|-------------| +| [OpenClaw](benchmarks/openclaw.md) | 60 enterprise document analysis tasks on RTX 5090 | +| [RAG](benchmarks/rag.md) | MultihopRAG and NarrativeQA on Qwen3-32B and DeepSeek-R1 | ## Reference | Document | Description | |----------|-------------| | [API Reference](reference/api.md) | Pipeline, InferenceConfig, HTTP endpoints | -| [Benchmarks](reference/benchmarks.md) | GPU vs CPU performance analysis and methodology | - +| [Benchmarks](reference/benchmarks.md) | GPU vs CPU performance methodology | ## Quick Links - [Examples](../examples/) +- [Paper](https://arxiv.org/abs/2511.03475) diff --git a/docs/benchmarks/openclaw.md b/docs/benchmarks/openclaw.md new file mode 100644 index 0000000..0f6e438 --- /dev/null +++ b/docs/benchmarks/openclaw.md @@ -0,0 +1,54 @@ +# OpenClaw Benchmark + +Evaluation of ContextPilot on [OpenClaw](https://openclaw.ai) agent workloads using the [claw-tasks](https://github.com/EfficientContext/ClawTasks) dataset. + +## Setup + +| Setting | Value | +|---------|-------| +| Model | Qwen3-4B-Instruct-2507 | +| Engine | SGLang 0.5.9 | +| GPU | single RTX 5090 | +| Context Length | 131,072 tokens | +| Dataset | 60 tasks, 22 documents (490 KB), ~250 turns | +| Baseline | OpenClaw → SGLang (direct) | +| Treatment | OpenClaw → ContextPilot → SGLang (proxy) | + +## Results + +``` + Avg P50 P99 +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +Prompt Tokens + OpenClaw + SGLang 45,771 44,570 92,785 + OpenClaw + ContextPilot + SGLang 33,622 32,526 51,581 + Δ -26.5% -27.0% -44.4% + +Wall Time (s) + OpenClaw + SGLang 26.1 25.2 68.8 + OpenClaw + ContextPilot + SGLang 20.8 21.8 50.4 + Δ -20.4% -13.3% -26.6% + +Completion Tokens + OpenClaw + SGLang 765 1004 1024 + OpenClaw + ContextPilot + SGLang 758 981 1024 + Δ -0.9% -2.3% +0.0% + +Accuracy (substantive output) + OpenClaw + SGLang 245/245 (100.0%) + OpenClaw + ContextPilot + SGLang 245/245 (100.0%) +━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ +``` + +## Reproduce + +```bash +git clone https://github.com/EfficientContext/ClawTasks.git +cd ClawTasks +python scripts/run_bench.py --gpu 0 +python scripts/analyze.py results/results.jsonl +``` + +## Raw Data + +To generate raw results, run the benchmark using [claw-tasks](https://github.com/EfficientContext/ClawTasks). Results are saved to `results/results.jsonl` (490 data points: 60 scenarios, each run with and without ContextPilot, ~4 turns each). diff --git a/docs/benchmarks/rag.md b/docs/benchmarks/rag.md new file mode 100644 index 0000000..72719f5 --- /dev/null +++ b/docs/benchmarks/rag.md @@ -0,0 +1,25 @@ +# RAG Benchmark Results + +## Qwen3-32B on 4×A6000 + +Single-node academic RAG with a 32B model on consumer GPUs. + +| Benchmark | Method | Prefill TP (tok/s) | Cache Hit | F1 (%) | +|-----------|--------|--------------------|-----------|--------| +| MultihopRAG | SGLang | 7,290 | 4.64% | 60.42 | +| | **SGLang + ContextPilot** | **14,214** | **33.97%** | **64.39** | +| NarrativeQA | SGLang | 7,921 | 5.91% | 28.41 | +| | **SGLang + ContextPilot** | **12,117** | **20.82%** | **29.64** | + +## DeepSeek-R1-671B on 16×H20 + +Production-scale 671B MoE inference on a multi-node GPU cluster. + +| Benchmark | Method | Prefill TP (tok/s) | Cache Hit | F1 (%) | +|-----------|--------|--------------------|-----------|--------| +| MultihopRAG | SGLang | 9,636 | 5.12% | 64.15 | +| | **SGLang + ContextPilot** | **17,498** | **60.37%** | **64.68** | +| NarrativeQA | SGLang | 8,687 | 6.08% | 40.20 | +| | **SGLang + ContextPilot** | **13,201** | **38.24%** | **41.08** | + +For methodology and full results, see the [paper](https://arxiv.org/abs/2511.03475). diff --git a/docs/guides/cache_sync.md b/docs/guides/cache_sync.md new file mode 100644 index 0000000..a6d13ae --- /dev/null +++ b/docs/guides/cache_sync.md @@ -0,0 +1,85 @@ +# Cache Synchronization + +ContextPilot maintains a Context Index that tracks what content is currently cached in the inference backend. This index drives reordering and dedup decisions. Keeping it in sync with the backend's actual cache state is critical. + +The sync strategy depends on whether you control the inference engine: + +| | Self-hosted (SGLang, vLLM, llama.cpp) | Cloud provider (OpenAI, Anthropic, etc.) | +|--|---------------------------------------|----------------------------------------| +| You deploy the engine | Yes | No | +| API protocol | OpenAI-compatible | OpenAI-compatible | +| Can patch the engine | Yes → eviction callbacks | No → TTL estimation | +| Sync accuracy | Exact | Approximate | + +Both use the same OpenAI-compatible API. The difference is whether ContextPilot can install a hook into the engine's cache eviction path. + +## Self-hosted: Eviction Callbacks + +When you deploy SGLang, vLLM, or llama.cpp yourself, ContextPilot patches the engine's KV cache at runtime to report evictions: + +``` +SGLang Radix Cache evicts an entry + │ + ▼ +_sglang_hook.py intercepts RadixCache.evict() + │ collects evicted request_ids + │ + ▼ +POST /evict {"request_ids": ["req-1", "req-2"]} + │ + ▼ +ContextPilot removes entries from Context Index + │ + ▼ +Next reorder uses updated index (no stale entries) +``` + +This is exact — ContextPilot knows precisely what is and isn't cached. No guessing. + +The hook is installed automatically at import time via `contextpilot_hook.pth`. No engine modification needed. + +## Cloud Provider: TTL Estimation + +When using a cloud provider's API (OpenAI, Anthropic, MiniMax), you can't patch the engine. These providers cache prompts with a TTL window but provide no eviction callback. ContextPilot models the cache state locally: + +``` +Request sent to cloud API + │ + ▼ +Response received + │ + ├─ cache_read_tokens > 0 → cache hit confirmed + │ ├─ TTL timer refreshed + │ └─ Context Index node access time updated + │ + └─ cache_creation_tokens > 0 → new cache entry + └─ TTL timer started + + ...time passes... + + │ + ▼ +TTL expires (~5 min Anthropic, ~5-10 min OpenAI) + │ + ▼ +Entry removed from TTL tracker + │ + ▼ +Next request: ContextPilot no longer marks this content for caching +``` + +The worst case is a missed cache hint (ContextPilot thinks content expired when it's still cached). This means one request won't get the `cache_control` marker — but the cloud may still cache-hit on its own. It never causes incorrect behavior. + +### Per-Request TTL + +Each request gets its own TTL entry, even if multiple requests share the same prefix. This is important: + +``` +req-001: prefix [A, B, C] → TTL entry for req-001 +req-002: prefix [A, B, D] → TTL entry for req-002 + +req-001 expires → only req-001 removed from TTL +req-002 still active → prefix [A, B] still considered cached +``` + +If req-001 and req-002 share prefix [A, B], evicting req-001 doesn't affect req-002's TTL. The shared prefix stays in the cache model as long as any request using it is alive. diff --git a/docs/guides/how_it_works.md b/docs/guides/how_it_works.md new file mode 100644 index 0000000..47aad58 --- /dev/null +++ b/docs/guides/how_it_works.md @@ -0,0 +1,88 @@ +# How It Works + +ContextPilot optimizes LLM inference through two mechanisms: **Reorder** and **Deduplication**. Both operate on the request before it reaches the inference engine. + +## Reorder + +LLM engines (SGLang, vLLM) use a prefix cache (Radix Cache) — if two requests share the same token prefix, the second request reuses the cached KV computation. But when context blocks are assembled in different orders across requests, the prefix changes and the cache misses. + +Reorder solves this by sorting context blocks into a canonical order that maximizes prefix sharing: + +``` +Without reorder: With reorder: + Request 1: [A, B, C] → cached Request 1: [A, B, C] → cached + Request 2: [D, A, B] → cache miss Request 2: [A, B, D] → prefix [A, B] hit! + (prefix D≠A, no match) (ContextPilot moves cached A, B to front) +``` + +ContextPilot builds a Context Index (hierarchical clustering tree) that groups similar documents. When a new request arrives, it reorders the documents so that: +1. Documents already in the cache come first (maximizing prefix reuse) +2. Similar documents are adjacent (maximizing future prefix sharing) + +See [cache_sync.md](cache_sync.md) for how the Context Index stays in sync with the engine's cache. + +## Deduplication + +When an agent reads multiple documents that share content, the conversation history accumulates redundant text. ContextPilot removes this redundancy through two layers: + +### ContextBlock-level deduplication + +If a tool result is byte-identical to an earlier one in the same conversation, replace it with a reference. This is handled by the intercept pipeline's `single_doc_hashes` for cross-turn deduplication, and the conversation tracker's `deduplicate()` for the `/reorder` API. + +### Content-level deduplication + +Like file system deduplication — when two documents share content blocks (e.g., contracts from the same template), only the first occurrence is kept. Subsequent identical blocks are replaced with pointers. + +**How it works:** + +1. Split each tool result into blocks using content-defined boundaries (line hash mod M) +2. Hash each block (SHA-256) +3. If a block matches one from a different tool result, replace it with a pointer +4. Never deduplicate within the same tool result + +``` +Contract A (kept intact): Contract B (after deduplication): +┌────────────────────────┐ ┌────────────────────────┐ +│ Art. 1 — Definitions │ │ Art. 1 — (unique part) │ +│ Art. 2 — Scope │ │ [... "Art. 2 — Scope" │ +│ Art. 3 — Term │ │ — see earlier result] │ +│ ... │ │ ... │ +│ Art. 16 — General │ │ [... "Art. 16" │ +│ Art. 17 — Cloud Terms │ │ — see earlier result] │ +└────────────────────────┘ │ Art. 17 — AI Terms │ + 45 KB └────────────────────────┘ + 15 KB +``` + +Each pointer quotes the first line of the replaced block so the LLM knows what content it refers to. The LLM resolves pointers via attention to the original content above. + +**Why content-defined chunking?** Fixed-size blocks have an alignment problem — if content shifts by a few lines, all block boundaries change and hashes stop matching. Content-defined boundaries (determined by `hash(line) % M`) adapt to the content, so the same text produces the same blocks regardless of its position in the document. This is the same principle used in file system deduplication (Rabin fingerprint). + +### Tuning block size + +The `--chunk-modulus M` flag controls the average block size (default: 13 lines per block). + +```bash +python -m contextpilot.server.http_server --chunk-modulus 13 # default +``` + +| M | Avg block size | Best for | +|---|---------------|----------| +| 7-10 | ~7-10 lines | Documents with scattered differences (e.g., config files, code with inline changes) | +| 11-15 | ~11-15 lines | Template documents with concentrated differences (contracts, proposals) — **default** | +| 20-30 | ~20-30 lines | Documents that are nearly identical (only a few lines differ) | + +Smaller M = more blocks = more fine-grained deduplication, but each pointer has ~80 chars of overhead. Larger M = fewer blocks = less overhead, but may miss partial overlaps if differences are scattered. + +### API + +```python +from contextpilot.dedup import dedup_chat_completions, DedupResult + +result = dedup_chat_completions(body, chunk_modulus=13) +# result.blocks_deduped — number of blocks replaced with pointers +# result.blocks_total — total blocks processed +# result.chars_saved — characters removed +``` + +The deduplication module (`contextpilot/dedup/`) is independent of the server — it operates on message content with no dependency on the Context Index or cache state. diff --git a/docs/guides/mem0.md b/docs/guides/mem0.md index 0ed8887..45f70e1 100644 --- a/docs/guides/mem0.md +++ b/docs/guides/mem0.md @@ -58,7 +58,7 @@ python examples/mem0_locomo_example.py ## Results -Aggregate across all 10 LoCoMo conversations, Qwen2.5-7B-Instruct on 2xA6000 (SGLang, tp=2): +Aggregate across all 10 LoCoMo conversations, Qwen3-4B on A6000 (SGLang, tp=1): | k | mode | ttft | ttft delta | judge | |---|---|---|---|---| diff --git a/docs/guides/openclaw.md b/docs/guides/openclaw.md new file mode 100644 index 0000000..fc19202 --- /dev/null +++ b/docs/guides/openclaw.md @@ -0,0 +1,277 @@ +# ContextPilot + OpenClaw Integration Guide + +## Architecture + +
+OpenClaw + ContextPilot Pipeline +
+ +ContextPilot acts as a transparent HTTP proxy. OpenClaw sends requests to the proxy instead of directly to the LLM API. The proxy deduplicates shared content across tool results, reorders documents, and forwards to the backend. + +## Why This Matters for OpenClaw + +OpenClaw's search and memory retrieval results appear as **tool_result messages** in the conversation history, not in the system prompt. When multiple search results are returned, their ordering affects the LLM's attention and response quality. + +ContextPilot: +1. **Reorder**: Reorders documents within tool results to maximize prefix cache hits (multi-doc tool results) +2. **Dedup**: ContextBlock-level and content-level deduplication across tool results — identical content replaced with back-references, reducing prefill tokens + +Results from reorder and dedup are cached and reapplied on subsequent turns to keep the prefix consistent across the conversation (prefix cache alignment). See [Cache Synchronization](cache_sync.md) for how ContextPilot stays in sync with the inference engine's cache. + +## Setup + +### Quick Start (one command) + +```bash +# Clone and run +git clone https://github.com/EfficientContext/ContextPilot.git +cd ContextPilot/examples/openclaw +bash setup.sh anthropic # or: bash setup.sh openai +``` + +The script installs ContextPilot, generates a config, and starts the proxy. + +### Docker + +```bash +cd ContextPilot/examples/openclaw +docker compose up -d + +# OpenAI instead of Anthropic: +CONTEXTPILOT_BACKEND_URL=https://api.openai.com docker compose up -d +``` + +### Manual + +```bash +pip install contextpilot +python -m contextpilot.server.http_server \ + --stateless --port 8765 \ + --infer-api-url https://api.anthropic.com +``` + +## Configure OpenClaw + +### Option A: UI (recommended) + +1. Open OpenClaw +2. Go to **Settings → Models** +3. Add a custom provider: + +| Field | Value | +|-------|-------| +| Name | `contextpilot-anthropic` | +| Base URL | `http://localhost:8765/v1` | +| API Key | your Anthropic API key | +| API | `anthropic-messages` | +| Model ID | `claude-opus-4-6` | + +4. Select the model and start chatting + +### Option B: Config file + +Merge into `~/.openclaw/openclaw.json`: + +```json +{ + "models": { + "providers": { + "contextpilot-anthropic": { + "baseUrl": "http://localhost:8765/v1", + "apiKey": "${ANTHROPIC_API_KEY}", + "api": "anthropic-messages", + "headers": { + "X-ContextPilot-Scope": "all" + }, + "models": [ + { + "id": "claude-opus-4-6", + "name": "Claude Opus 4.6 (via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 200000, + "maxTokens": 32000 + } + ] + } + } + } +} +``` + +For OpenAI, use `api: "openai-completions"` and point `--infer-api-url` to `https://api.openai.com`. See `examples/openclaw/openclaw.json.example` for both providers. + +### Option C: Self-hosted model via SGLang + +For self-hosted models, ContextPilot proxies between OpenClaw and SGLang: + +``` +OpenClaw ──▶ ContextPilot Proxy (server:8765) ──▶ SGLang (server:30000) +``` + +Start SGLang with tool calling support: + +```bash +python -m sglang.launch_server \ + --model-path Qwen/Qwen3.5-27B \ + --tool-call-parser qwen3_coder \ + --port 30000 +``` + +Start ContextPilot proxy: + +```bash +python -m contextpilot.server.http_server \ + --port 8765 \ + --infer-api-url http://localhost:30000 \ + --model Qwen/Qwen3.5-27B +``` + +Configure OpenClaw (replace `` with your server's IP): + +```bash +# Requires jq (install: sudo apt install jq / brew install jq) +jq ' + .agents.defaults.model.primary = "contextpilot-sglang/Qwen/Qwen3.5-27B" | + .models = { + "mode": "merge", + "providers": { + "contextpilot-sglang": { + "baseUrl": "http://:8765/v1", + "apiKey": "placeholder", + "api": "openai-completions", + "headers": {"X-ContextPilot-Scope": "all"}, + "models": [{ + "id": "Qwen/Qwen3.5-27B", + "name": "Qwen 3.5 27B (SGLang via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 131072, + "maxTokens": 8192 + }] + } + } + } +' ~/.openclaw/openclaw.json > /tmp/oc.json && mv /tmp/oc.json ~/.openclaw/openclaw.json +``` + +Then restart: + +```bash +pkill -f openclaw && openclaw gateway start && openclaw tui +``` + +
+Without jq: manually edit ~/.openclaw/openclaw.json + +1. Change `agents.defaults.model.primary` to `"contextpilot-sglang/Qwen/Qwen3.5-27B"` +2. Add a `"models"` key at the top level: + +```json +"models": { + "mode": "merge", + "providers": { + "contextpilot-sglang": { + "baseUrl": "http://:8765/v1", + "apiKey": "placeholder", + "api": "openai-completions", + "headers": { "X-ContextPilot-Scope": "all" }, + "models": [{ + "id": "Qwen/Qwen3.5-27B", + "name": "Qwen 3.5 27B (SGLang via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 131072, + "maxTokens": 8192 + }] + } + } +} +``` + +
+ +> **Important**: Use the server's IP address (not hostname) in `baseUrl` to avoid IPv6 DNS resolution issues in Node.js/WSL environments. `--tool-call-parser` is required for OpenClaw's tool loop to work. + +## Verify + +Check the `X-ContextPilot-Result` response header: + +``` +X-ContextPilot-Result: {"intercepted":true,"documents_reordered":true,"total_documents":8,"sources":{"system":1,"tool_results":2}} +``` + +If the header is absent, the request had fewer than 2 extractable documents (nothing to reorder). + +## Document Extraction + +ContextPilot auto-detects these formats in both system prompts and tool results: + +| Format | Pattern | Typical Source | +|--------|---------|----------------| +| XML tags | `...` | RAG systems | +| File tags | `...` | Code search | +| Numbered | `[1] doc [2] doc` | Search rankings | +| Separator | docs split by `---` or `===` | Text chunking | +| Markdown headers | sections split by `#`/`##` | Structured docs | + +Auto-detection priority: XML > Numbered > Separator > Markdown headers. + +## Scope Control + +| `X-ContextPilot-Scope` | System Prompt | Tool Results | +|:---:|:---:|:---:| +| `all` (default) | reordered | reordered | +| `system` | reordered | untouched | +| `tool_results` | untouched | reordered | + +Set via headers in the OpenClaw provider config, or per-request. + +## Full Header Reference + +| Header | Description | Default | +|--------|-------------|---------| +| `X-ContextPilot-Enabled` | Enable/disable | `true` | +| `X-ContextPilot-Mode` | Extraction mode | `auto` | +| `X-ContextPilot-Scope` | Which messages to process | `all` | +| `X-ContextPilot-Tag` | Custom XML tag name | `document` | +| `X-ContextPilot-Separator` | Custom separator | `---` | +| `X-ContextPilot-Alpha` | Clustering distance parameter | `0.001` | +| `X-ContextPilot-Linkage` | Clustering linkage method | `average` | + +For details on how reorder and dedup work, see [How It Works](how_it_works.md). + +## Benchmark Results + +Tested on [claw-tasks](https://github.com/EfficientContext/ClawTasks) — 60 enterprise document analysis tasks, 22 documents (490 KB), ~250 turns. + +``` + Avg P50 P99 +Prompt Tokens + OpenClaw + SGLang 45,771 44,570 92,785 + OpenClaw + ContextPilot + SGLang 33,622 32,526 51,581 + Δ -26.5% -27.0% -44.4% + +Wall Time (s) + OpenClaw + SGLang 26.1 25.2 68.8 + OpenClaw + ContextPilot + SGLang 20.8 21.8 50.4 + Δ -20.4% -13.3% -26.6% + +Accuracy 245/245 245/245 +``` + +See [`docs/benchmarks/openclaw.md`](../benchmarks/openclaw.md) for details. + +## Troubleshooting + +**No `X-ContextPilot-Result` header** — Request had < 2 extractable documents. Check that search/memory tools are returning multiple results. + +**Connection refused** — Proxy not running. Check `curl http://localhost:8765/health`. + +**`Connection error.` from OpenClaw (Node.js)** — IPv6 DNS resolution failure. Use IP address in `baseUrl`, or `export NODE_OPTIONS="--dns-result-order=ipv4first"`. + +**401/403 from backend** — API key not set or invalid. The proxy forwards auth headers as-is. + +**Tool call appears as XML text, agent stops** — SGLang not parsing tool calls into structured `tool_calls`. Add `--tool-call-parser qwen3_coder` (or the appropriate parser for your model) to SGLang launch command. + +**Tool results not reordered** — Check scope is `all` or `tool_results`. Verify tool results use a supported format. diff --git a/docs/images/openclaw-cp.png b/docs/images/openclaw-cp.png new file mode 100644 index 0000000..b5823a9 Binary files /dev/null and b/docs/images/openclaw-cp.png differ diff --git a/examples/openclaw/Dockerfile b/examples/openclaw/Dockerfile new file mode 100644 index 0000000..6e19b30 --- /dev/null +++ b/examples/openclaw/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.12-slim + +WORKDIR /app + +RUN pip install --no-cache-dir contextpilot + +EXPOSE 8765 + +ENTRYPOINT ["python", "-m", "contextpilot.server.http_server", "--stateless"] +CMD ["--port", "8765", "--infer-api-url", "https://api.anthropic.com"] diff --git a/examples/openclaw/README.md b/examples/openclaw/README.md new file mode 100644 index 0000000..d497ec8 --- /dev/null +++ b/examples/openclaw/README.md @@ -0,0 +1,219 @@ +# ContextPilot + OpenClaw Quick Start + +## Flow + +``` +1. Start ContextPilot proxy (one command) +2. OpenClaw UI → Settings → Models (add custom provider pointing to proxy) +3. Select model, start chatting (done) +``` + +``` +OpenClaw UI ──▶ ContextPilot Proxy (localhost:8765) ──▶ LLM API (Anthropic/OpenAI) + ◀──────────────── responses flow back ◀────────────────── +``` + +## One-Click Setup + +### Option A: Shell script + +```bash +# Anthropic (default) +bash setup.sh anthropic + +# OpenAI +bash setup.sh openai +``` + +The script will: +1. Check Python, install ContextPilot +2. Generate the provider config JSON +3. Print what to enter in OpenClaw UI +4. Start the proxy + +### Option B: Docker Compose + +```bash +# Anthropic (default) +docker compose up -d + +# OpenAI +CONTEXTPILOT_BACKEND_URL=https://api.openai.com docker compose up -d +``` + +## Manual Setup + +### Step 1: Start proxy + +```bash +pip install contextpilot + +python -m contextpilot.server.http_server \ + --stateless --port 8765 \ + --infer-api-url https://api.anthropic.com +``` + +### Step 2: Add provider in OpenClaw + +Open OpenClaw UI → **Settings** → **Models** → add custom provider: + +| Field | Value | +|-------|-------| +| Base URL | `http://localhost:8765/v1` | +| API Key | your Anthropic/OpenAI key | +| API | `anthropic-messages` or `openai-completions` | +| Headers | `X-ContextPilot-Scope: all` | + +Or merge `openclaw.json.example` into `~/.openclaw/openclaw.json`: + +```json +{ + "models": { + "providers": { + "contextpilot-anthropic": { + "baseUrl": "http://localhost:8765/v1", + "apiKey": "${ANTHROPIC_API_KEY}", + "api": "anthropic-messages", + "headers": { "X-ContextPilot-Scope": "all" }, + "models": [{ "id": "claude-opus-4-6", "name": "Claude Opus 4.6 (via ContextPilot)" }] + } + } + } +} +``` + +### Step 3: Select model and chat + +In OpenClaw, select the model from the ContextPilot provider. Search/memory results in tool_results will be automatically reordered. + +### Step 4: Verify + +Check the `X-ContextPilot-Result` response header for metadata: + +``` +X-ContextPilot-Result: {"intercepted":true,"documents_reordered":true,"total_documents":5,"sources":{"system":1,"tool_results":1}} +``` + +## Using with SGLang (Self-Hosted Models) + +For self-hosted models served by SGLang, ContextPilot acts as a proxy between OpenClaw and SGLang, enabling automatic document reordering and KV cache tracking. + +``` +OpenClaw (WSL/local) ──▶ ContextPilot Proxy (server:8765) ──▶ SGLang (server:30000) +``` + +### Step 1: Start SGLang with tool calling support + +```bash +python -m sglang.launch_server \ + --model-path Qwen/Qwen3.5-27B \ + --tool-call-parser qwen3_coder \ + --port 30000 +``` + +> `--tool-call-parser` is required for OpenClaw's tool loop to work. Without it, tool calls are output as plain text and the agent loop won't continue. + +### Step 2: Start ContextPilot proxy + +```bash +python -m contextpilot.server.http_server \ + --port 8765 \ + --infer-api-url http://localhost:30000 \ + --model Qwen/Qwen3.5-27B +``` + +### Step 3: Configure OpenClaw + +Patch your OpenClaw config (replace `` with your server's IP): + +```bash +# Requires jq (install: sudo apt install jq / brew install jq) +jq ' + .agents.defaults.model.primary = "contextpilot-sglang/Qwen/Qwen3.5-27B" | + .models = { + "mode": "merge", + "providers": { + "contextpilot-sglang": { + "baseUrl": "http://:8765/v1", + "apiKey": "placeholder", + "api": "openai-completions", + "headers": {"X-ContextPilot-Scope": "all"}, + "models": [{ + "id": "Qwen/Qwen3.5-27B", + "name": "Qwen 3.5 27B (SGLang via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 131072, + "maxTokens": 8192 + }] + } + } + } +' ~/.openclaw/openclaw.json > /tmp/oc.json && mv /tmp/oc.json ~/.openclaw/openclaw.json +``` + +Then restart OpenClaw: + +```bash +pkill -f openclaw && openclaw gateway start && openclaw tui +``` + +
+Without jq: manually add this JSON to ~/.openclaw/openclaw.json + +Add a `"models"` key at the top level and change `agents.defaults.model.primary`: + +```json +{ + "agents": { + "defaults": { + "model": { + "primary": "contextpilot-sglang/Qwen/Qwen3.5-27B" + } + } + }, + "models": { + "mode": "merge", + "providers": { + "contextpilot-sglang": { + "baseUrl": "http://:8765/v1", + "apiKey": "placeholder", + "api": "openai-completions", + "headers": { "X-ContextPilot-Scope": "all" }, + "models": [ + { + "id": "Qwen/Qwen3.5-27B", + "name": "Qwen 3.5 27B (SGLang via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 131072, + "maxTokens": 8192 + } + ] + } + } + } +} +``` + +
+ +> **Important**: Use the server's IP address (not hostname) in `baseUrl` to avoid IPv6 DNS resolution issues in Node.js/WSL environments. + +### Troubleshooting + +| Symptom | Cause | Fix | +|---------|-------|-----| +| `Connection error.` | IPv6 DNS resolution fails in Node.js | Use IP address in `baseUrl`, or `export NODE_OPTIONS="--dns-result-order=ipv4first"` | +| Tool call appears as XML text, agent stops | SGLang not parsing tool calls | Add `--tool-call-parser qwen3_coder` to SGLang launch command | +| `Invalid JSON body` | Multiline curl command broken by shell | Use single-line JSON in curl | + +## Scope Control + +| Header value | System prompt | Tool results | +|:---:|:---:|:---:| +| `all` (default) | reordered | reordered | +| `system` | reordered | untouched | +| `tool_results` | untouched | reordered | + +See [SKILL.md](SKILL.md) for the full header reference. diff --git a/examples/openclaw/SKILL.md b/examples/openclaw/SKILL.md new file mode 100644 index 0000000..8024a7c --- /dev/null +++ b/examples/openclaw/SKILL.md @@ -0,0 +1,47 @@ +--- +name: contextpilot +description: Optimize document ordering in LLM context for better retrieval performance +version: 1.0.0 +triggers: + - /contextpilot + - /cp +--- + +# ContextPilot Integration + +You have ContextPilot enabled as a transparent proxy. All your LLM API requests are automatically routed through ContextPilot, which reorders documents in your context window for optimal retrieval performance. + +## What ContextPilot Does + +When your request contains multiple documents (in system prompts or tool results), ContextPilot: + +1. **Extracts** documents from XML tags (``, ``, etc.), numbered lists, separators, or markdown headers +2. **Clusters** documents by semantic similarity using hierarchical clustering +3. **Reorders** documents to minimize attention distance between related content +4. **Reconstructs** the request preserving the original format + +## Supported Document Formats + +- XML tags: `...`, `...` +- Numbered: `[1] doc [2] doc [3] doc` +- Separator: docs separated by `---` or `===` +- Markdown headers: sections split by `#` or `##` headers + +## How to Verify + +Check the `X-ContextPilot-Result` response header: + +``` +X-ContextPilot-Result: {"intercepted":true,"documents_reordered":true,"total_documents":5,"sources":{"system":1,"tool_results":1}} +``` + +## Configuration Headers + +Send these headers with your API requests to control behavior: + +| Header | Values | Default | +|--------|--------|---------| +| `X-ContextPilot-Enabled` | `true`/`false` | `true` | +| `X-ContextPilot-Mode` | `auto`/`xml_tag`/`numbered`/`separator`/`markdown_header` | `auto` | +| `X-ContextPilot-Scope` | `all`/`system`/`tool_results` | `all` | +| `X-ContextPilot-Alpha` | float | `0.001` | diff --git a/examples/openclaw/docker-compose.yml b/examples/openclaw/docker-compose.yml new file mode 100644 index 0000000..585f08f --- /dev/null +++ b/examples/openclaw/docker-compose.yml @@ -0,0 +1,18 @@ +services: + contextpilot: + build: . + ports: + - "${CONTEXTPILOT_PORT:-8765}:8765" + environment: + - CONTEXTPILOT_INFER_API_URL=${CONTEXTPILOT_BACKEND_URL:-https://api.anthropic.com} + command: + - "--port" + - "8765" + - "--infer-api-url" + - "${CONTEXTPILOT_BACKEND_URL:-https://api.anthropic.com}" + restart: unless-stopped + healthcheck: + test: ["CMD", "python", "-c", "import urllib.request; urllib.request.urlopen('http://localhost:8765/health')"] + interval: 30s + timeout: 5s + retries: 3 diff --git a/examples/openclaw/openclaw.json.example b/examples/openclaw/openclaw.json.example new file mode 100644 index 0000000..e102010 --- /dev/null +++ b/examples/openclaw/openclaw.json.example @@ -0,0 +1,50 @@ +{ + "models": { + "providers": { + "contextpilot-anthropic": { + "baseUrl": "http://localhost:8765/v1", + "apiKey": "${ANTHROPIC_API_KEY}", + "api": "anthropic-messages", + "headers": { + "X-ContextPilot-Scope": "all" + }, + "models": [ + { + "id": "claude-opus-4-6", + "name": "Claude Opus 4.6 (via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 200000, + "maxTokens": 32000 + }, + { + "id": "claude-sonnet-4-5", + "name": "Claude Sonnet 4.5 (via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 200000, + "maxTokens": 16000 + } + ] + }, + "contextpilot-openai": { + "baseUrl": "http://localhost:8765/v1", + "apiKey": "${OPENAI_API_KEY}", + "api": "openai-completions", + "headers": { + "X-ContextPilot-Scope": "all" + }, + "models": [ + { + "id": "gpt-4o", + "name": "GPT-4o (via ContextPilot)", + "reasoning": false, + "input": ["text"], + "contextWindow": 128000, + "maxTokens": 16384 + } + ] + } + } + } +} diff --git a/examples/openclaw/setup.sh b/examples/openclaw/setup.sh new file mode 100755 index 0000000..a63d99f --- /dev/null +++ b/examples/openclaw/setup.sh @@ -0,0 +1,127 @@ +#!/usr/bin/env bash +set -euo pipefail + +# ContextPilot + OpenClaw one-click setup +# Usage: bash setup.sh [anthropic|openai] + +PROVIDER="${1:-anthropic}" +PORT="${CONTEXTPILOT_PORT:-8765}" + +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' + +info() { echo -e "${GREEN}[INFO]${NC} $*"; } +warn() { echo -e "${YELLOW}[WARN]${NC} $*"; } +error() { echo -e "${RED}[ERROR]${NC} $*"; exit 1; } + +# ── Check Python ─────────────────────────────────────────────────────────── +info "Checking Python version..." +if ! command -v python3 &>/dev/null; then + error "Python 3 not found. Install Python 3.10+ first." +fi + +PY_VER=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') +PY_MAJOR=$(echo "$PY_VER" | cut -d. -f1) +PY_MINOR=$(echo "$PY_VER" | cut -d. -f2) + +if [ "$PY_MAJOR" -lt 3 ] || { [ "$PY_MAJOR" -eq 3 ] && [ "$PY_MINOR" -lt 10 ]; }; then + error "Python 3.10+ required, found $PY_VER" +fi +info "Python $PY_VER OK" + +# ── Install ContextPilot ────────────────────────────────────────────────── +info "Installing ContextPilot..." +if python3 -c "import contextpilot" 2>/dev/null; then + info "ContextPilot already installed, upgrading..." + pip install --upgrade contextpilot -q +else + pip install contextpilot -q +fi + +# ── Determine backend URL and API type ──────────────────────────────────── +case "$PROVIDER" in + anthropic) + BACKEND_URL="https://api.anthropic.com" + API_KEY_VAR="ANTHROPIC_API_KEY" + API_TYPE="anthropic-messages" + MODEL_ID="claude-opus-4-6" + MODEL_NAME="Claude Opus 4.6 (via ContextPilot)" + CTX_WINDOW=200000 + MAX_TOKENS=32000 + ;; + openai) + BACKEND_URL="https://api.openai.com" + API_KEY_VAR="OPENAI_API_KEY" + API_TYPE="openai-completions" + MODEL_ID="gpt-4o" + MODEL_NAME="GPT-4o (via ContextPilot)" + CTX_WINDOW=128000 + MAX_TOKENS=16384 + ;; + *) + error "Unknown provider: $PROVIDER. Use 'anthropic' or 'openai'." + ;; +esac + +# ── Check API key ───────────────────────────────────────────────────────── +if [ -z "${!API_KEY_VAR:-}" ]; then + warn "$API_KEY_VAR not set. The proxy will start but requests will fail without a valid key." + warn "Set it with: export $API_KEY_VAR=your-key" +fi + +# ── Generate OpenClaw provider config ───────────────────────────────────── +OPENCLAW_DIR="$HOME/.openclaw" +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PATCH_FILE="$SCRIPT_DIR/contextpilot-provider.json" + +cat > "$PATCH_FILE" < $BACKEND_URL" +info "Press Ctrl+C to stop." +echo "" + +exec python3 -m contextpilot.server.http_server \ + --stateless \ + --port "$PORT" \ + --infer-api-url "$BACKEND_URL" diff --git a/pyproject.toml b/pyproject.toml index 9757c7f..a1abfef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "contextpilot" -version = "0.3.5.post2" +version = "0.4.0" description = "Fast Long-Context Inference via Context Reuse" readme = "README.md" requires-python = ">=3.10" @@ -68,7 +68,7 @@ filterwarnings = [ ] [tool.bumpver] -current_version = "0.2.0" +current_version = "0.4.0" version_pattern = "MAJOR.MINOR.PATCH" commit_message = "Bump version {old_version} -> {new_version}" commit = true diff --git a/tests/test_cloud_adapters.py b/tests/test_cloud_adapters.py new file mode 100644 index 0000000..1cdc0c3 --- /dev/null +++ b/tests/test_cloud_adapters.py @@ -0,0 +1,284 @@ +"""Tests for cloud provider adapters.""" + +import copy + +import pytest + +from contextpilot.server.cloud_adapters import ( + get_cloud_adapter, + AnthropicAdapter, + OpenAIAdapter, + MiniMaxAdapter, + CacheMetrics, + TTLTier, +) + + +class TestGetCloudAdapter: + def test_anthropic(self): + adapter = get_cloud_adapter("anthropic") + assert isinstance(adapter, AnthropicAdapter) + + def test_openai(self): + adapter = get_cloud_adapter("openai") + assert isinstance(adapter, OpenAIAdapter) + + def test_minimax(self): + adapter = get_cloud_adapter("minimax") + assert isinstance(adapter, MiniMaxAdapter) + + def test_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown cloud provider"): + get_cloud_adapter("unknown") + + +class TestAnthropicAdapter: + @pytest.fixture + def adapter(self): + return AnthropicAdapter() + + def test_provider_name(self, adapter): + assert adapter.provider_name == "anthropic" + + def test_api_url(self, adapter): + assert ( + adapter.get_api_url("/v1/messages") + == "https://api.anthropic.com/v1/messages" + ) + + def test_target_path(self, adapter): + assert adapter.get_target_path() == "/v1/messages" + + def test_default_ttl_seconds(self, adapter): + assert adapter.get_default_ttl_seconds() == 300 + + def test_extended_ttl_seconds(self, adapter): + assert adapter.get_extended_ttl_seconds() == 3600 + + def test_auth_headers(self, adapter): + headers = adapter.get_auth_headers("sk-ant-test-key") + assert headers["x-api-key"] == "sk-ant-test-key" + assert "anthropic-version" in headers + assert headers["content-type"] == "application/json" + + def test_inject_cache_control_string_system(self, adapter): + body = { + "system": "You are a helpful assistant.", + "messages": [{"role": "user", "content": "hi"}], + } + result = adapter.inject_cache_control(body, set()) + assert isinstance(result["system"], list) + assert result["system"][0]["type"] == "text" + assert result["system"][0]["text"] == "You are a helpful assistant." + assert result["system"][0]["cache_control"] == {"type": "ephemeral"} + + def test_inject_cache_control_list_system(self, adapter): + body = { + "system": [ + {"type": "text", "text": "First block"}, + {"type": "text", "text": "Second block"}, + ], + "messages": [], + } + result = adapter.inject_cache_control(body, set()) + assert "cache_control" not in result["system"][0] + assert result["system"][1]["cache_control"] == {"type": "ephemeral"} + + def test_inject_cache_control_no_system(self, adapter): + body = {"messages": [{"role": "user", "content": "hi"}]} + result = adapter.inject_cache_control(body, set()) + assert "system" not in result + + def test_inject_cache_control_tool_result_large(self, adapter): + large_content = "x" * 2000 + body = { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "t1", + "content": large_content, + } + ], + } + ] + } + result = adapter.inject_cache_control(body, set()) + tool_block = result["messages"][0]["content"][0] + assert tool_block["cache_control"] == {"type": "ephemeral"} + + def test_inject_cache_control_tool_result_small_unchanged(self, adapter): + body = { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "t1", + "content": "short", + } + ], + } + ] + } + result = adapter.inject_cache_control(body, set()) + tool_block = result["messages"][0]["content"][0] + assert "cache_control" not in tool_block + + def test_inject_does_not_mutate_original(self, adapter): + body = { + "system": "test", + "messages": [{"role": "user", "content": "hi"}], + } + original = copy.deepcopy(body) + adapter.inject_cache_control(body, set()) + assert body == original + + def test_parse_cache_metrics(self, adapter): + response = { + "usage": { + "cache_creation_input_tokens": 1000, + "cache_read_input_tokens": 500, + "input_tokens": 1500, + "output_tokens": 200, + } + } + metrics = adapter.parse_cache_metrics(response) + assert metrics.cache_creation_tokens == 1000 + assert metrics.cache_read_tokens == 500 + assert metrics.input_tokens == 1500 + assert metrics.output_tokens == 200 + + def test_parse_cache_metrics_empty_usage(self, adapter): + metrics = adapter.parse_cache_metrics({}) + assert metrics.cache_creation_tokens == 0 + assert metrics.cache_read_tokens == 0 + + +class TestOpenAIAdapter: + @pytest.fixture + def adapter(self): + return OpenAIAdapter() + + def test_provider_name(self, adapter): + assert adapter.provider_name == "openai" + + def test_api_url(self, adapter): + assert ( + adapter.get_api_url("/v1/chat/completions") + == "https://api.openai.com/v1/chat/completions" + ) + + def test_target_path(self, adapter): + assert adapter.get_target_path() == "/v1/chat/completions" + + def test_default_ttl_seconds(self, adapter): + assert adapter.get_default_ttl_seconds() == 3600 + + def test_extended_ttl_seconds(self, adapter): + assert adapter.get_extended_ttl_seconds() == 86400 + + def test_auth_headers(self, adapter): + headers = adapter.get_auth_headers("sk-test-key") + assert headers["Authorization"] == "Bearer sk-test-key" + assert headers["Content-Type"] == "application/json" + + def test_inject_no_retention_by_default(self, adapter): + body = { + "model": "gpt-4o", + "messages": [{"role": "user", "content": "hello"}], + } + result = adapter.inject_cache_control(body, set()) + assert "prompt_cache_retention" not in result + + def test_inject_24h_when_extended(self, adapter): + adapter.configured_ttl = TTLTier.LONG + body = { + "model": "gpt-4.1", + "messages": [{"role": "user", "content": "hello"}], + } + result = adapter.inject_cache_control(body, set()) + assert result["prompt_cache_retention"] == "24h" + + def test_parse_cache_metrics_with_cached_tokens(self, adapter): + response = { + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 200, + "prompt_tokens_details": {"cached_tokens": 800}, + } + } + metrics = adapter.parse_cache_metrics(response) + assert metrics.cache_read_tokens == 800 + assert metrics.cache_creation_tokens == 200 + assert metrics.input_tokens == 1000 + assert metrics.output_tokens == 200 + + def test_parse_cache_metrics_no_cached(self, adapter): + response = {"usage": {"prompt_tokens": 500, "completion_tokens": 100}} + metrics = adapter.parse_cache_metrics(response) + assert metrics.cache_read_tokens == 0 + assert metrics.cache_creation_tokens == 0 + assert metrics.input_tokens == 500 + + def test_parse_cache_metrics_empty(self, adapter): + metrics = adapter.parse_cache_metrics({}) + assert metrics.cache_read_tokens == 0 + + +class TestMiniMaxAdapter: + @pytest.fixture + def adapter(self): + return MiniMaxAdapter() + + def test_provider_name(self, adapter): + assert adapter.provider_name == "minimax" + + def test_api_url(self, adapter): + url = adapter.get_api_url("/v1/messages") + assert "minimax.io" in url + + def test_target_path(self, adapter): + assert adapter.get_target_path() == "/v1/messages" + + def test_default_ttl_seconds(self, adapter): + assert adapter.get_default_ttl_seconds() == 300 + + def test_no_extended_cache(self, adapter): + assert adapter.get_extended_ttl_seconds() is None + assert not adapter.supports_extended_cache + + def test_auth_headers(self, adapter): + headers = adapter.get_auth_headers("mm-key-123") + assert headers["x-api-key"] == "mm-key-123" + + def test_inject_cache_control_same_as_anthropic(self, adapter): + body = { + "system": "You are an assistant", + "messages": [{"role": "user", "content": "hi"}], + } + result = adapter.inject_cache_control(body, set()) + assert isinstance(result["system"], list) + assert result["system"][0]["cache_control"] == {"type": "ephemeral"} + + def test_inject_does_not_mutate_original(self, adapter): + body = {"system": "test", "messages": []} + original = copy.deepcopy(body) + adapter.inject_cache_control(body, set()) + assert body == original + + def test_parse_cache_metrics(self, adapter): + response = { + "usage": { + "cache_creation_input_tokens": 2000, + "cache_read_input_tokens": 1000, + "input_tokens": 3000, + "output_tokens": 500, + } + } + metrics = adapter.parse_cache_metrics(response) + assert metrics.cache_creation_tokens == 2000 + assert metrics.cache_read_tokens == 1000 diff --git a/tests/test_cloud_proxy_integration.py b/tests/test_cloud_proxy_integration.py new file mode 100644 index 0000000..1b242cb --- /dev/null +++ b/tests/test_cloud_proxy_integration.py @@ -0,0 +1,294 @@ +"""Integration tests for cloud prompt cache proxy flow. + +Tests the end-to-end interaction between: +- TTLEvictionPolicy (cache state tracking) +- CloudProviderAdapters (cache control injection + metrics parsing) +- The combined flow: request → inject → forward → parse → update state +""" + +import copy +import time +from unittest.mock import patch + +import pytest + +from contextpilot.server.ttl_eviction import TTLEvictionPolicy, TTLTier, CacheMetrics +from contextpilot.server.cloud_adapters import ( + get_cloud_adapter, + AnthropicAdapter, + OpenAIAdapter, + MiniMaxAdapter, +) + + +class TestEndToEndCacheFlow: + """Test the full cache lifecycle: inject → forward → parse → track.""" + + def _simulate_request_response(self, adapter, policy, body, response_body): + """Simulate one request-response cycle through the cloud proxy.""" + import hashlib + import json + + cached_hashes = policy.get_cached_hashes() + modified_body = adapter.inject_cache_control(body, cached_hashes) + + content_hash = hashlib.sha256( + json.dumps( + modified_body.get("system", ""), sort_keys=True, ensure_ascii=False + ).encode() + ).hexdigest()[:24] + + metrics = adapter.parse_cache_metrics(response_body) + policy.update_from_response(metrics, content_hash) + + return modified_body, metrics, content_hash + + def test_anthropic_first_request_creates_cache(self): + adapter = get_cloud_adapter("anthropic") + policy = TTLEvictionPolicy(default_ttl=TTLTier.SHORT) + + body = { + "system": "You are a helpful AI assistant.", + "messages": [{"role": "user", "content": "Hello!"}], + } + response = { + "usage": { + "cache_creation_input_tokens": 500, + "cache_read_input_tokens": 0, + "input_tokens": 510, + "output_tokens": 50, + } + } + + modified_body, metrics, content_hash = self._simulate_request_response( + adapter, policy, body, response + ) + + assert metrics.cache_creation_tokens == 500 + assert metrics.cache_read_tokens == 0 + assert policy.is_cached(content_hash) + assert policy.get_cached_count() == 1 + + def test_anthropic_second_request_hits_cache(self): + adapter = get_cloud_adapter("anthropic") + policy = TTLEvictionPolicy(default_ttl=TTLTier.SHORT) + + body = { + "system": "You are a helpful AI assistant.", + "messages": [{"role": "user", "content": "Hello!"}], + } + + # First request: cache write + resp1 = { + "usage": { + "cache_creation_input_tokens": 500, + "cache_read_input_tokens": 0, + "input_tokens": 510, + "output_tokens": 50, + } + } + _, _, hash1 = self._simulate_request_response(adapter, policy, body, resp1) + + # Second request: cache hit + body2 = copy.deepcopy(body) + body2["messages"] = [{"role": "user", "content": "Follow-up question."}] + resp2 = { + "usage": { + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 500, + "input_tokens": 510, + "output_tokens": 60, + } + } + _, metrics2, hash2 = self._simulate_request_response( + adapter, policy, body2, resp2 + ) + + assert metrics2.cache_read_tokens == 500 + assert metrics2.cache_creation_tokens == 0 + stats = policy.get_stats() + assert stats["total_hits"] >= 1 + + def test_ttl_expiry_triggers_recache(self): + adapter = get_cloud_adapter("anthropic") + policy = TTLEvictionPolicy(default_ttl=TTLTier.SHORT) + + body = { + "system": "You are a coding assistant.", + "messages": [{"role": "user", "content": "hi"}], + } + + # First request: create cache + resp1 = { + "usage": { + "cache_creation_input_tokens": 1000, + "cache_read_input_tokens": 0, + "input_tokens": 1010, + "output_tokens": 50, + } + } + _, _, hash1 = self._simulate_request_response(adapter, policy, body, resp1) + assert policy.is_cached(hash1) + + # Simulate 6 minutes passing (TTL=5min) + future = time.time() + 360 + with patch("contextpilot.server.ttl_eviction.time.time", return_value=future): + policy.evict_expired() + assert not policy.is_cached(hash1) + assert policy.get_cached_count() == 0 + + # Third request: must re-cache + resp3 = { + "usage": { + "cache_creation_input_tokens": 1000, + "cache_read_input_tokens": 0, + "input_tokens": 1010, + "output_tokens": 50, + } + } + _, metrics3, hash3 = self._simulate_request_response( + adapter, policy, body, resp3 + ) + assert metrics3.cache_creation_tokens == 1000 + assert policy.is_cached(hash3) + + def test_openai_extended_caching_tracking(self): + adapter = get_cloud_adapter("openai") + adapter.configured_ttl = TTLTier.LONG + policy = TTLEvictionPolicy( + default_ttl_seconds=adapter.get_extended_ttl_seconds(), + ) + + body = { + "model": "gpt-4.1", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"}, + ], + } + + modified = adapter.inject_cache_control(body, policy.get_cached_hashes()) + assert modified["prompt_cache_retention"] == "24h" + assert "prompt_cache_retention" not in body + + # First response: no cache yet + resp1 = { + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 100, + "prompt_tokens_details": {"cached_tokens": 0}, + } + } + metrics1 = adapter.parse_cache_metrics(resp1) + assert metrics1.cache_read_tokens == 0 + + # Second response: cached + resp2 = { + "usage": { + "prompt_tokens": 1000, + "completion_tokens": 100, + "prompt_tokens_details": {"cached_tokens": 800}, + } + } + metrics2 = adapter.parse_cache_metrics(resp2) + assert metrics2.cache_read_tokens == 800 + assert metrics2.cache_creation_tokens == 200 + + def test_minimax_cache_flow(self): + adapter = get_cloud_adapter("minimax") + policy = TTLEvictionPolicy(default_ttl=TTLTier.SHORT) + + body = { + "system": "You are a literary analysis assistant.", + "messages": [{"role": "user", "content": "Analyze themes in this book."}], + } + + modified, metrics, content_hash = self._simulate_request_response( + adapter, + policy, + body, + { + "usage": { + "cache_creation_input_tokens": 188086, + "cache_read_input_tokens": 0, + "input_tokens": 21, + "output_tokens": 393, + } + }, + ) + + # System should have cache_control injected + assert isinstance(modified["system"], list) + assert modified["system"][0]["cache_control"] == {"type": "ephemeral"} + + # Cache state updated + assert policy.is_cached(content_hash) + assert policy.get_total_cached_tokens() == 188086 + + +class TestMultiProviderCacheIsolation: + """Verify each provider's cache state is independent.""" + + def test_separate_policies_per_provider(self): + policies = { + name: TTLEvictionPolicy( + default_ttl_seconds=get_cloud_adapter(name).get_default_ttl_seconds() + ) + for name in ["anthropic", "openai", "minimax"] + } + + policies["anthropic"].add_entry("shared_hash", token_count=100) + assert policies["anthropic"].is_cached("shared_hash") + assert not policies["openai"].is_cached("shared_hash") + assert not policies["minimax"].is_cached("shared_hash") + + +class TestLongTTLTier: + """Test 1-hour TTL tier.""" + + def test_long_ttl_survives_5min(self): + policy = TTLEvictionPolicy(default_ttl=TTLTier.LONG) + policy.add_entry("long_lived", token_count=5000) + + future_6min = time.time() + 360 + with patch( + "contextpilot.server.ttl_eviction.time.time", return_value=future_6min + ): + policy.evict_expired() + assert policy.is_cached("long_lived") + + def test_long_ttl_expires_after_24hr(self): + policy = TTLEvictionPolicy(default_ttl=TTLTier.LONG, default_ttl_seconds=86400) + policy.add_entry("long_lived", content_hash="long_lived", token_count=5000) + + future_25hr = time.time() + 90000 + with patch( + "contextpilot.server.ttl_eviction.time.time", return_value=future_25hr + ): + evicted = policy.evict_expired() + assert len(evicted) == 1 + assert evicted[0].content_hash == "long_lived" + + +class TestCacheStatsAccumulation: + """Test that cache statistics accumulate correctly over multiple requests.""" + + def test_stats_accumulate_across_requests(self): + adapter = get_cloud_adapter("anthropic") + policy = TTLEvictionPolicy() + + for i in range(5): + policy.add_entry(f"hash_{i}", token_count=100 * (i + 1)) + + for i in range(5): + policy.is_cached(f"hash_{i}") + policy.is_cached("nonexistent_1") + policy.is_cached("nonexistent_2") + + stats = policy.get_stats() + assert stats["active_entries"] == 5 + assert stats["total_hits"] == 5 + assert stats["total_misses"] == 2 + assert stats["total_additions"] == 5 + assert stats["total_cached_tokens"] == 100 + 200 + 300 + 400 + 500 + assert stats["hit_rate_pct"] == pytest.approx(5 / 7 * 100, abs=0.1) diff --git a/tests/test_http_intercept.py b/tests/test_http_intercept.py new file mode 100644 index 0000000..d746595 --- /dev/null +++ b/tests/test_http_intercept.py @@ -0,0 +1,1015 @@ +"""Integration tests for HTTP intercept proxy endpoints. + +Uses httpx AsyncClient with FastAPI's TestClient (no real server needed) +and patches aiohttp to mock the backend LLM server. +""" + +import json +import pytest +import asyncio +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi.testclient import TestClient + +from contextpilot.server.http_server import app, _init_config + +import contextpilot.server.http_server as http_mod + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +class FakeHeaders: + """Fake multidict-like headers for aiohttp responses.""" + + def __init__(self, headers=None): + self._headers = headers or {"content-type": "application/json"} + + def items(self): + return self._headers.items() + + def get(self, key, default=None): + return self._headers.get(key, default) + + +class FakeResponse: + """Fake aiohttp response for mocking.""" + + def __init__(self, json_body, status=200): + self._json = json_body + self.status = status + self.content = FakeStreamContent() + self.headers = FakeHeaders() + + async def json(self): + return self._json + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + +class FakeStreamContent: + """Fake async iterator for streaming response.""" + + def __init__(self, chunks=None): + self._chunks = chunks or [b"data: {}\n\n", b"data: [DONE]\n\n"] + + async def iter_any(self): + for c in self._chunks: + yield c + + +class FakeStreamResponse(FakeResponse): + """Fake response that yields SSE chunks.""" + + def __init__(self, chunks=None, status=200): + super().__init__({}, status) + self.content = FakeStreamContent(chunks) + self.headers = FakeHeaders({"content-type": "text/event-stream"}) + + +class FakeSession: + """Fake aiohttp.ClientSession.""" + + def __init__(self, response=None): + self._response = response or FakeResponse( + { + "choices": [{"message": {"content": "Hello"}}], + "usage": {"total_tokens": 10}, + } + ) + + def post(self, url, json=None, headers=None): + self._last_url = url + self._last_json = json + self._last_headers = headers + return self._response + + def get(self, url, headers=None): + self._last_url = url + self._last_headers = headers + return self._response + + +def _cp_meta(resp): + """Extract ContextPilot metadata from the X-ContextPilot-Result response header.""" + raw = resp.headers.get("x-contextpilot-result") + if raw is None: + return {} + return json.loads(raw) + + +@pytest.fixture +def mock_session(): + """Provide a FakeSession and patch it into http_server globals.""" + session = FakeSession() + return session + + +@pytest.fixture +def client(mock_session): + """FastAPI test client with mocked backend.""" + # Patch the module-level globals + original_session = http_mod._aiohttp_session + original_url = http_mod._infer_api_url + original_intercept_index = http_mod._intercept_index + original_state = http_mod._intercept_state + http_mod._aiohttp_session = mock_session + http_mod._infer_api_url = "http://mock-backend:30000" + http_mod._intercept_index = None # reset so each test starts fresh + http_mod._intercept_state = http_mod._InterceptConvState() + try: + yield TestClient(app, raise_server_exceptions=False) + finally: + http_mod._aiohttp_session = original_session + http_mod._infer_api_url = original_url + http_mod._intercept_index = original_intercept_index + http_mod._intercept_state = original_state + + +# ============================================================================ +# Helpers +# ============================================================================ + + +def _warmup(client, path, body): + """Prime the intercept index so subsequent calls use build_incremental. + + After priming, resets conversation state so the actual test request + sees a clean slate (but the clustering index remains primed). + """ + resp = client.post(path, json=body) + assert resp.status_code == 200 + # Keep _intercept_index primed, but reset conversation tracking. + http_mod._intercept_state = http_mod._InterceptConvState() + return resp + + +# Documents with clear clustering signal: two auth docs share words +# (token, authentication, secure, for) while the database doc shares none. +# schedule_only produces [0, 2, 1] — auth docs grouped, database last. +_AUTH_DOC = "JWT token validation and rotation policy for secure authentication" +_DB_DOC = "Database connection pooling sharding replication backup strategy" +_OAUTH_DOC = "OAuth2 authentication token refresh flow for secure login sessions" + + +# ============================================================================ +# Full intercept flow +# ============================================================================ + + +class TestOpenAIIntercept: + def test_first_request_builds_index_no_reorder(self, client, mock_session): + """First request builds index but does NOT reorder (index empty, 1 context).""" + body = { + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ), + }, + {"role": "user", "content": "Summarize the documents."}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + assert mock_session._last_url == "http://mock-backend:30000/v1/chat/completions" + # First request: index empty → no reorder, no header. + meta = _cp_meta(resp) + assert meta == {} + # Body forwarded unmodified. + forwarded = mock_session._last_json + sys_content = forwarded["messages"][0]["content"] + assert "" in sys_content + + def test_second_session_may_reorder(self, client, mock_session): + """After index is built (session 1), session 2 uses build_incremental.""" + body = { + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": ( + "\n" + "JWT token validation and rotation policy for secure authentication\n" + "Database connection pooling sharding replication backup strategy\n" + "OAuth2 authentication token refresh flow for secure login sessions\n" + "" + ), + }, + {"role": "user", "content": "Summarize."}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + # Body is forwarded (may or may not be reordered depending on index state). + assert "_contextpilot" not in resp.json() + forwarded = mock_session._last_json + sys_content = forwarded["messages"][0]["content"] + assert "" in sys_content + assert "" in sys_content + + def test_bypass_when_disabled(self, client, mock_session): + """When X-ContextPilot-Enabled: false, forward unmodified.""" + body = { + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": "AB", + }, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post( + "/v1/chat/completions", + json=body, + headers={"X-ContextPilot-Enabled": "false"}, + ) + assert resp.status_code == 200 + # Body forwarded unmodified + forwarded = mock_session._last_json + assert forwarded["messages"][0]["content"] == body["messages"][0]["content"] + # No contextpilot metadata + assert _cp_meta(resp) == {} + + def test_bypass_when_no_docs(self, client, mock_session): + """When system message has no extractable docs, forward unmodified.""" + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + assert forwarded["messages"][0]["content"] == "You are a helpful assistant." + + def test_bypass_when_single_doc(self, client, mock_session): + """Single doc => nothing to reorder, forward unmodified.""" + body = { + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": "Only one", + }, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + assert _cp_meta(resp) == {} + + def test_numbered_format(self, client, mock_session): + """Numbered format extraction and forwarding.""" + body = { + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": "[1] First document [2] Second document [3] Third document", + }, + {"role": "user", "content": "Summarize"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + sys_content = forwarded["messages"][0]["content"] + assert "[1]" in sys_content + assert "[2]" in sys_content + + +class TestAnthropicIntercept: + def test_basic_intercept(self, client, mock_session): + body = { + "model": "claude-3-opus-20240229", + "system": "ABC", + "messages": [{"role": "user", "content": "Hello"}], + } + resp = client.post("/v1/messages", json=body) + assert resp.status_code == 200 + assert mock_session._last_url == "http://mock-backend:30000/v1/messages" + forwarded = mock_session._last_json + assert "" in forwarded["system"] + + def test_bypass_no_system(self, client, mock_session): + body = { + "model": "claude-3-opus-20240229", + "messages": [{"role": "user", "content": "Hello"}], + } + resp = client.post("/v1/messages", json=body) + assert resp.status_code == 200 + + +# ============================================================================ +# Streaming +# ============================================================================ + + +class TestStreaming: + def test_streaming_passthrough(self, mock_session): + """Streaming responses are passed through.""" + chunks = [b'data: {"id":"1"}\n\n', b"data: [DONE]\n\n"] + stream_resp = FakeStreamResponse(chunks) + session = FakeSession(stream_resp) + + original_session = http_mod._aiohttp_session + original_url = http_mod._infer_api_url + http_mod._aiohttp_session = session + http_mod._infer_api_url = "http://mock-backend:30000" + try: + client = TestClient(app, raise_server_exceptions=False) + body = { + "model": "gpt-4", + "stream": True, + "messages": [ + { + "role": "system", + "content": "AB", + }, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + # Streaming response content + content = resp.content + assert b"data:" in content + finally: + http_mod._aiohttp_session = original_session + http_mod._infer_api_url = original_url + + +# ============================================================================ +# Catch-all still works +# ============================================================================ + + +class TestCatchAllProxy: + def test_other_v1_paths_still_proxied(self, client, mock_session): + """Other /v1/* paths still go to the catch-all proxy.""" + mock_session._response = FakeResponse({"models": [{"id": "gpt-4"}]}) + resp = client.get("/v1/models") + assert resp.status_code == 200 + assert mock_session._last_url == "http://mock-backend:30000/v1/models" + + def test_completions_still_works(self, client, mock_session): + """POST /v1/completions still handled by existing proxy.""" + mock_session._response = FakeResponse( + {"choices": [{"text": "hello"}], "usage": {"total_tokens": 5}} + ) + body = {"model": "gpt-4", "prompt": "Hello"} + resp = client.post("/v1/completions", json=body) + assert resp.status_code == 200 + + def test_completions_metadata_in_header(self, client, mock_session): + """proxy_completions puts metadata in X-ContextPilot-Result header, not body.""" + mock_session._response = FakeResponse( + {"choices": [{"text": "hello"}], "usage": {"total_tokens": 5}} + ) + body = {"model": "gpt-4", "prompt": "Hello", "request_id": "req-abc123"} + resp = client.post("/v1/completions", json=body) + assert resp.status_code == 200 + # Metadata should be in header + meta = _cp_meta(resp) + assert meta.get("request_id") == "req-abc123" + assert meta.get("tokens_reported") == 5 + # Body should NOT contain _contextpilot + assert "_contextpilot" not in resp.json() + + +# ============================================================================ +# Header forwarding +# ============================================================================ + + +class TestHeaderForwarding: + def test_auth_headers_forwarded(self, client, mock_session): + """Authorization headers are forwarded, X-ContextPilot-* stripped.""" + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "Plain text."}, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post( + "/v1/chat/completions", + json=body, + headers={ + "Authorization": "Bearer sk-test", + "X-ContextPilot-Mode": "auto", + }, + ) + assert resp.status_code == 200 + outbound = mock_session._last_headers or {} + # Auth should be forwarded + assert outbound.get("authorization") or outbound.get("Authorization") + # X-ContextPilot-* should be stripped + for k in outbound: + assert not k.lower().startswith("x-contextpilot-") + + +# ============================================================================ +# Tool result intercept +# ============================================================================ + + +class TestToolResultIntercept: + def test_openai_tool_result_forwarded(self, client, mock_session): + """OpenAI tool results with docs are extracted and forwarded.""" + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Search for X"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc1", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "tc1", + "content": ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ), + }, + {"role": "user", "content": "Now summarize."}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + tool_content = forwarded["messages"][3]["content"] + assert "" in tool_content + + def test_anthropic_tool_result_forwarded(self, client, mock_session): + """Anthropic tool_result content blocks are extracted and forwarded.""" + body = { + "model": "claude-3-opus-20240229", + "system": "You are a helper.", + "messages": [ + {"role": "user", "content": "Search for X"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tu1", + "name": "search", + "input": {"query": "X"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tu1", + "content": ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ), + }, + ], + }, + {"role": "user", "content": "Now summarize."}, + ], + } + resp = client.post("/v1/messages", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + tr_content = forwarded["messages"][2]["content"][0]["content"] + assert "" in tr_content + + def test_openai_json_tool_result_forwarded(self, client, mock_session): + """OpenClaw-style JSON tool results are forwarded correctly.""" + import json as _json + + results = [ + {"path": "auth.md", "snippet": _AUTH_DOC, "score": 0.9}, + {"path": "db.md", "snippet": _DB_DOC, "score": 0.8}, + {"path": "oauth.md", "snippet": _OAUTH_DOC, "score": 0.7}, + ] + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Search for auth"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc1", + "type": "function", + "function": {"name": "memory_search", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "tc1", + "content": _json.dumps( + {"results": results, "citations": "auto"}, indent=2 + ), + }, + {"role": "user", "content": "Now summarize."}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + tool_content = _json.loads(forwarded["messages"][3]["content"]) + assert "results" in tool_content + assert len(tool_content["results"]) == 3 + assert tool_content["citations"] == "auto" + + def test_anthropic_json_tool_result_forwarded(self, client, mock_session): + """Anthropic-format JSON tool results are forwarded correctly.""" + import json as _json + + results = [ + {"title": "Auth guide", "url": "https://a.com", "description": _AUTH_DOC}, + {"title": "DB guide", "url": "https://b.com", "description": _DB_DOC}, + {"title": "OAuth guide", "url": "https://c.com", "description": _OAUTH_DOC}, + ] + body = { + "model": "claude-3-opus-20240229", + "system": "You are a helper.", + "messages": [ + {"role": "user", "content": "Search the web"}, + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "tu1", + "name": "web_search", + "input": {"query": "auth"}, + }, + ], + }, + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "tu1", + "content": _json.dumps( + {"results": results, "provider": "brave"}, indent=2 + ), + }, + ], + }, + {"role": "user", "content": "Summarize."}, + ], + } + resp = client.post("/v1/messages", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + tr_content = _json.loads(forwarded["messages"][2]["content"][0]["content"]) + assert "results" in tr_content + assert len(tr_content["results"]) == 3 + assert tr_content["provider"] == "brave" + + +# ============================================================================ +# RID injection in intercept (stateful mode) +# ============================================================================ + + +class TestInterceptRidInjection: + def test_rid_injected_in_stateful_mode(self, mock_session): + """In stateful mode, intercept injects rid into forwarded body.""" + original_session = http_mod._aiohttp_session + original_url = http_mod._infer_api_url + original_stateless = http_mod._stateless_mode + original_index = http_mod._index + http_mod._aiohttp_session = mock_session + http_mod._infer_api_url = "http://mock-backend:30000" + http_mod._stateless_mode = False + http_mod._index = MagicMock() # non-None → stateful + try: + client = TestClient(app, raise_server_exceptions=False) + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "Plain text."}, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + assert "rid" in forwarded + assert forwarded["rid"].startswith("req-") + finally: + http_mod._aiohttp_session = original_session + http_mod._infer_api_url = original_url + http_mod._stateless_mode = original_stateless + http_mod._index = original_index + + def test_no_rid_in_stateless_mode(self, client, mock_session): + """In stateless mode (default for tests), no rid is injected.""" + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "Plain text."}, + {"role": "user", "content": "Hello"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + assert "rid" not in forwarded + + +# ============================================================================ +# Conversation-aware intercept (skip old, dedup new) +# ============================================================================ + + +# Distinct document sets for multi-turn dedup testing. +_DOC_CACHE = "Redis cache invalidation strategy with TTL and LRU eviction" +_DOC_DEPLOY = "Kubernetes deployment rolling update blue green canary strategy" + + +class TestConversationAwareIntercept: + """Tests for multi-turn skip / dedup / reorder behaviour.""" + + def _make_body(self, system, tool_results=None): + """Build an OpenAI chat body with optional tool result messages.""" + messages = [ + {"role": "system", "content": system}, + {"role": "user", "content": "Hello"}, + ] + if tool_results: + for i, content in enumerate(tool_results): + messages.append( + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": f"tc{i}", + "type": "function", + "function": {"name": "search", "arguments": "{}"}, + } + ], + } + ) + messages.append( + {"role": "tool", "tool_call_id": f"tc{i}", "content": content} + ) + messages.append({"role": "user", "content": f"Follow-up {i}"}) + return {"model": "gpt-4", "messages": messages} + + def test_old_tool_result_skipped_on_second_turn(self, client, mock_session): + tool_content = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ) + system = "You are a helpful assistant." + # Turn 1: first tool result, builds index (no reorder on first call). + body1 = self._make_body(system, [tool_content]) + resp1 = client.post("/v1/chat/completions", json=body1) + assert resp1.status_code == 200 + + # Turn 2: same body again — tool result should be skipped. + resp2 = client.post("/v1/chat/completions", json=body1) + assert resp2.status_code == 200 + meta2 = _cp_meta(resp2) + assert meta2 == {} + forwarded = mock_session._last_json + assert forwarded["messages"][3]["content"] == tool_content + + def test_new_tool_result_processed_old_skipped(self, client, mock_session): + system = "You are a helpful assistant." + tool1 = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ) + # New tool result with 3 distinct docs for reliable clustering. + tool2 = ( + f"\n" + f"{_DOC_CACHE}\n" + f"{_DOC_DEPLOY}\n" + f"API rate limiting throttling circuit breaker backpressure\n" + f"" + ) + # Turn 1: one tool result. + body1 = self._make_body(system, [tool1]) + resp1 = client.post("/v1/chat/completions", json=body1) + assert resp1.status_code == 200 + + # Turn 2: old tool result + new one. + body2 = self._make_body(system, [tool1, tool2]) + resp2 = client.post("/v1/chat/completions", json=body2) + assert resp2.status_code == 200 + meta2 = _cp_meta(resp2) + assert meta2 == {} + forwarded = mock_session._last_json + tool2_forwarded = forwarded["messages"][6]["content"] + assert tool2_forwarded.count("") == 3 + assert _DOC_CACHE in tool2_forwarded + assert _DOC_DEPLOY in tool2_forwarded + + def test_cross_tool_result_dedup(self, client, mock_session): + """Documents seen in a previous tool result get deduped in new ones.""" + system = "You are a helpful assistant." + tool1 = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ) + # Second search returns 2 OLD docs + 2 NEW docs. + tool2 = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_OAUTH_DOC}\n" + f"{_DOC_CACHE}\n" + f"{_DOC_DEPLOY}\n" + f"" + ) + # Turn 1 + body1 = self._make_body(system, [tool1]) + resp1 = client.post("/v1/chat/completions", json=body1) + assert resp1.status_code == 200 + + # Turn 2 with overlapping docs + body2 = self._make_body(system, [tool1, tool2]) + resp2 = client.post("/v1/chat/completions", json=body2) + assert resp2.status_code == 200 + meta2 = _cp_meta(resp2) + assert meta2.get("intercepted") is True + assert ( + meta2.get("documents_deduplicated", 0) == 2 + ) # AUTH + OAUTH deduped in tool2 + + def test_system_prompt_processed_once(self, client, mock_session): + """System prompt docs are only processed on the first turn. + + First turn: index empty → no reorder (system_processed set to True). + Second turn: system_processed=True → skipped. + """ + system = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ) + body = self._make_body(system) + # Turn 1: system processed (no reorder, but flag set). + resp1 = client.post("/v1/chat/completions", json=body) + assert resp1.status_code == 200 + + # Turn 2: system NOT re-processed. + resp2 = client.post("/v1/chat/completions", json=body) + assert resp2.status_code == 200 + meta2 = _cp_meta(resp2) + assert meta2 == {} or meta2.get("sources", {}).get("system", 0) == 0 + + def test_new_session_resets_state(self, client, mock_session): + """A shorter messages array (new session) resets all intercept state.""" + system = "You are a helpful assistant." + tool_content = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_DB_DOC}\n" + f"{_OAUTH_DOC}\n" + f"" + ) + tool_overlap = ( + f"\n" + f"{_AUTH_DOC}\n" + f"{_OAUTH_DOC}\n" + f"{_DOC_CACHE}\n" + f"{_DOC_DEPLOY}\n" + f"" + ) + # Session 1: long conversation with tool result. + body_long = self._make_body(system, [tool_content]) + resp1 = client.post("/v1/chat/completions", json=body_long) + assert resp1.status_code == 200 + + body_overlap = self._make_body(system, [tool_content, tool_overlap]) + resp2 = client.post("/v1/chat/completions", json=body_overlap) + meta2 = _cp_meta(resp2) + assert meta2.get("documents_deduplicated", 0) == 2 + + # Session 2: shorter messages → triggers state reset. + body_short = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": system}, + {"role": "user", "content": "Fresh session"}, + ], + } + client.post("/v1/chat/completions", json=body_short) + + resp3 = client.post( + "/v1/chat/completions", json=self._make_body(system, [tool_overlap]) + ) + assert resp3.status_code == 200 + meta3 = _cp_meta(resp3) + assert meta3.get("documents_deduplicated", 0) == 0 + forwarded = mock_session._last_json + tool3_content = forwarded["messages"][3]["content"] + assert _AUTH_DOC in tool3_content + assert _OAUTH_DOC in tool3_content + + def test_json_tool_result_dedup(self, client, mock_session): + """JSON results format also gets cross-turn dedup. + + Note: dedup compares the full serialised JSON document string, so + entries must be byte-identical to be considered duplicates. + """ + import json as _json + + system = "You are a helper." + # Shared entry — identical dict so serialised form matches. + auth_entry = {"path": "auth.md", "snippet": _AUTH_DOC, "score": 0.9} + results1 = [ + auth_entry, + {"path": "db.md", "snippet": _DB_DOC, "score": 0.8}, + {"path": "oauth.md", "snippet": _OAUTH_DOC, "score": 0.7}, + ] + results2 = [ + auth_entry, # exact duplicate + {"path": "cache.md", "snippet": _DOC_CACHE, "score": 0.85}, + {"path": "deploy.md", "snippet": _DOC_DEPLOY, "score": 0.75}, + ] + tool1 = _json.dumps({"results": results1}, indent=2) + tool2 = _json.dumps({"results": results2}, indent=2) + + # Turn 1 + body1 = self._make_body(system, [tool1]) + resp1 = client.post("/v1/chat/completions", json=body1) + assert resp1.status_code == 200 + + # Turn 2: old tool + new tool with overlapping auth entry. + body2 = self._make_body(system, [tool1, tool2]) + resp2 = client.post("/v1/chat/completions", json=body2) + assert resp2.status_code == 200 + meta2 = _cp_meta(resp2) + assert meta2.get("intercepted") is True + assert ( + meta2.get("documents_deduplicated", 0) == 1 + ) # auth entry deduped in tool2 + forwarded = mock_session._last_json + tool2_forwarded = _json.loads(forwarded["messages"][6]["content"]) + assert len(tool2_forwarded["results"]) == 2 + + +# ============================================================================ +# External content marker stripping +# ============================================================================ + + +class TestExternalContentIdStripping: + """EXTERNAL_UNTRUSTED_CONTENT random ids are stripped before forwarding.""" + + def _wrap(self, text, marker_id="ab12cd34"): + """Simulate OpenClaw's wrapWebContent.""" + return ( + f'\n<<>>\n' + f"Source: Web Search\n---\n{text}\n" + f'<<>>' + ) + + def test_ids_stripped_from_forwarded_body(self, client, mock_session): + """Random marker ids are removed so identical content shares KV prefix.""" + wrapped_title = self._wrap("Example Title", "aabbccdd11223344") + wrapped_desc = self._wrap("Example description text", "1122334455667788") + import json as _json + + results = [ + { + "title": wrapped_title, + "url": "https://a.com", + "description": wrapped_desc, + }, + ] + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Search"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc1", + "type": "function", + "function": {"name": "web_search", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "tc1", + "content": _json.dumps({"results": results}), + }, + {"role": "user", "content": "Summarize"}, + ], + } + resp = client.post("/v1/chat/completions", json=body) + assert resp.status_code == 200 + forwarded = mock_session._last_json + tool_content = forwarded["messages"][3]["content"] + # Random ids should be stripped + assert 'id="aabbccdd11223344"' not in tool_content + assert 'id="1122334455667788"' not in tool_content + # Markers themselves preserved (without id) + assert "<<>>" in tool_content + assert "<<>>" in tool_content + + def test_different_ids_produce_same_forwarded_content(self, client, mock_session): + """Two requests with different random ids produce identical forwarded content.""" + import json as _json + + def _make_body(marker_id): + wrapped = self._wrap("Same content", marker_id) + results = [ + {"title": wrapped, "url": "https://a.com", "description": "plain"} + ] + return { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Hello"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "tc1", + "type": "function", + "function": {"name": "s", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "tc1", + "content": _json.dumps({"results": results}), + }, + {"role": "user", "content": "Go"}, + ], + } + + # Request 1 with id "aaaa" + resp1 = client.post("/v1/chat/completions", json=_make_body("aaaa0000bbbb1111")) + assert resp1.status_code == 200 + content1 = mock_session._last_json["messages"][3]["content"] + + # Reset intercept state for clean comparison + http_mod._intercept_state = http_mod._InterceptConvState() + + # Request 2 with different id "bbbb" + resp2 = client.post("/v1/chat/completions", json=_make_body("cccc2222dddd3333")) + assert resp2.status_code == 200 + content2 = mock_session._last_json["messages"][3]["content"] + + assert content1 == content2 diff --git a/tests/test_intercept_parser.py b/tests/test_intercept_parser.py new file mode 100644 index 0000000..0b05c9f --- /dev/null +++ b/tests/test_intercept_parser.py @@ -0,0 +1,1076 @@ +"""Unit tests for contextpilot.server.intercept_parser.""" + +import copy +import pytest +from contextpilot.server.intercept_parser import ( + InterceptConfig, + ExtractionResult, + ToolResultLocation, + MultiExtractionResult, + parse_intercept_headers, + extract_documents, + extract_from_openai_chat, + extract_from_anthropic_messages, + extract_from_openai_tool_results, + extract_from_anthropic_tool_results, + extract_all_openai, + extract_all_anthropic, + reconstruct_openai_chat, + reconstruct_anthropic_messages, + reconstruct_openai_tool_result, + reconstruct_anthropic_tool_result, + reconstruct_content, +) + + +# ============================================================================ +# Header parsing +# ============================================================================ + + +class TestParseHeaders: + def test_defaults(self): + config = parse_intercept_headers({}) + assert config.enabled is True + assert config.mode == "auto" + assert config.tag == "document" + assert config.separator == "---" + assert config.alpha == pytest.approx(0.001) + assert config.linkage_method == "average" + + def test_explicit_mode(self): + headers = { + "X-ContextPilot-Mode": "xml_tag", + "X-ContextPilot-Tag": "passage", + "X-ContextPilot-Alpha": "0.01", + "X-ContextPilot-Linkage": "complete", + } + config = parse_intercept_headers(headers) + assert config.mode == "xml_tag" + assert config.tag == "passage" + assert config.alpha == pytest.approx(0.01) + assert config.linkage_method == "complete" + + def test_disabled(self): + for val in ("false", "0", "no", "False", "NO"): + config = parse_intercept_headers({"X-ContextPilot-Enabled": val}) + assert config.enabled is False + + def test_case_insensitive_keys(self): + headers = {"x-contextpilot-mode": "numbered"} + config = parse_intercept_headers(headers) + assert config.mode == "numbered" + + def test_separator_header(self): + headers = {"X-ContextPilot-Separator": "==="} + config = parse_intercept_headers(headers) + assert config.separator == "===" + + def test_scope_header(self): + config = parse_intercept_headers({"X-ContextPilot-Scope": "tool_results"}) + assert config.scope == "tool_results" + + def test_scope_default(self): + config = parse_intercept_headers({}) + assert config.scope == "all" + + def test_scope_invalid_falls_back(self): + config = parse_intercept_headers({"X-ContextPilot-Scope": "invalid"}) + assert config.scope == "all" + + +# ============================================================================ +# XML tag extraction +# ============================================================================ + + +class TestXmlExtraction: + def test_basic_documents_wrapper(self): + text = "\nDoc A\nDoc B\n" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.mode == "xml_tag" + assert result.documents == ["Doc A", "Doc B"] + assert result.wrapper_tag == "documents" + assert result.item_tag == "document" + + def test_contexts_wrapper(self): + text = "FirstSecond" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["First", "Second"] + assert result.wrapper_tag == "contexts" + + def test_docs_wrapper(self): + text = "ABC" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["A", "B", "C"] + + def test_passages_wrapper(self): + text = "XY" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["X", "Y"] + + def test_references_wrapper(self): + text = "Ref1Ref2" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["Ref1", "Ref2"] + + def test_custom_tag_explicit_mode(self): + text = "Code ACode B" + config = InterceptConfig(mode="xml_tag", tag="snippet") + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["Code A", "Code B"] + + def test_prefix_suffix_preserved(self): + text = "Here are the docs:\n\nA\nB\n\nPlease answer." + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.prefix == "Here are the docs:\n" + assert result.suffix == "\nPlease answer." + + def test_no_wrapper_multiple_items(self): + text = "Doc 1\nDoc 2" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["Doc 1", "Doc 2"] + assert result.wrapper_tag == "" + + def test_multiline_content(self): + text = "\nLine 1\nLine 2\nLine 3\nLine 4\n" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert len(result.documents) == 2 + assert "Line 1\nLine 2" in result.documents[0] + + def test_single_item_returns_none_without_wrapper(self): + text = "Only one doc" + config = InterceptConfig() + result = extract_documents(text, config) + # No wrapper and only 1 item -> None (need >=2 for reordering) + assert result is None + + +# ============================================================================ +# Numbered extraction +# ============================================================================ + + +class TestNumberedExtraction: + def test_basic_numbered(self): + text = "[1] First document [2] Second document [3] Third document" + config = InterceptConfig(mode="numbered") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "numbered" + assert result.documents == ["First document", "Second document", "Third document"] + + def test_numbered_with_prefix(self): + text = "Retrieved documents:\n[1] Doc A [2] Doc B" + config = InterceptConfig(mode="numbered") + result = extract_documents(text, config) + assert result is not None + assert result.prefix == "Retrieved documents:\n" + assert result.documents == ["Doc A", "Doc B"] + + def test_numbered_with_newlines(self): + text = "[1] First doc\n[2] Second doc\n[3] Third doc" + config = InterceptConfig(mode="numbered") + result = extract_documents(text, config) + assert result is not None + assert len(result.documents) == 3 + + def test_single_numbered_returns_none(self): + text = "[1] Only one document" + config = InterceptConfig(mode="numbered") + result = extract_documents(text, config) + assert result is None + + +# ============================================================================ +# Separator extraction +# ============================================================================ + + +class TestSeparatorExtraction: + def test_basic_separator(self): + text = "Doc A\n---\nDoc B\n---\nDoc C" + config = InterceptConfig(mode="separator") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "separator" + assert result.documents == ["Doc A", "Doc B", "Doc C"] + + def test_equals_separator(self): + text = "Doc A\n===\nDoc B\n===\nDoc C" + config = InterceptConfig(mode="separator", separator="===") + result = extract_documents(text, config) + assert result is not None + assert result.documents == ["Doc A", "Doc B", "Doc C"] + + def test_auto_excludes_separator(self): + """separator excluded from auto — matches YAML frontmatter/structural content.""" + for sep in ("---", "==="): + text = f"First\n{sep}\nSecond\n{sep}\nThird" + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is None + + def test_explicit_triple_dash(self): + text = "First\n---\nSecond\n---\nThird" + config = InterceptConfig(mode="separator") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "separator" + assert result.documents == ["First", "Second", "Third"] + + def test_explicit_triple_equals(self): + text = "First\n===\nSecond\n===\nThird" + config = InterceptConfig(mode="separator", separator="===") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "separator" + + def test_single_separator_returns_none(self): + # Only one separator -> only 2 parts, still need >=2 docs + text = "Only one\n---\nTwo parts" + config = InterceptConfig(mode="separator") + result = extract_documents(text, config) + # 2 docs is fine (>=2) + assert result is not None + assert len(result.documents) == 2 + + def test_no_separator_returns_none(self): + text = "Just plain text with no separators" + config = InterceptConfig(mode="separator") + result = extract_documents(text, config) + assert result is None + + +# ============================================================================ +# JSON results extraction (OpenClaw tool results) +# ============================================================================ + + +class TestJsonResultsExtraction: + """OpenClaw tools return JSON.stringify(payload, null, 2) with a results array.""" + + def test_memory_search_results(self): + import json + text = json.dumps({ + "results": [ + {"path": "MEMORY.md", "snippet": "Use TypeScript", "score": 0.9}, + {"path": "notes.md", "snippet": "Prefer functional style", "score": 0.8}, + {"path": "config.md", "snippet": "Port 8080", "score": 0.7}, + ], + "citations": "auto", + }, indent=2) + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.mode == "json_results" + assert len(result.documents) == 3 + # Documents are path identifiers (used for clustering) + assert result.documents[0] == "MEMORY.md" + assert result.documents[1] == "notes.md" + assert result.documents[2] == "config.md" + # Full objects stored in json_items + assert result.json_items is not None + assert result.json_items[0]["path"] == "MEMORY.md" + + def test_web_search_results(self): + import json + text = json.dumps({ + "query": "python async", + "provider": "brave", + "results": [ + {"title": "Async IO in Python", "url": "https://a.com", "description": "Guide to asyncio"}, + {"title": "Python concurrency", "url": "https://b.com", "description": "Threading vs async"}, + ], + }, indent=2) + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.mode == "json_results" + assert len(result.documents) == 2 + # Documents are URL identifiers (used for clustering) + assert result.documents[0] == "https://a.com" + assert result.documents[1] == "https://b.com" + + def test_explicit_mode(self): + import json + text = json.dumps({"results": [{"a": 1}, {"b": 2}]}, indent=2) + config = InterceptConfig(mode="json_results") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "json_results" + assert len(result.documents) == 2 + + def test_single_result_returns_none(self): + import json + text = json.dumps({"results": [{"a": 1}]}, indent=2) + config = InterceptConfig() + result = extract_documents(text, config) + # Need >= 2 results to reorder + assert result is None + + def test_no_results_key_returns_none(self): + import json + text = json.dumps({"data": [1, 2, 3]}, indent=2) + config = InterceptConfig() + result = extract_documents(text, config) + assert result is None + + def test_not_json_returns_none(self): + text = "This is just plain text, not JSON" + config = InterceptConfig(mode="json_results") + result = extract_documents(text, config) + assert result is None + + def test_roundtrip(self): + import json + original = { + "results": [ + {"path": "a.md", "snippet": "alpha"}, + {"path": "b.md", "snippet": "beta"}, + {"path": "c.md", "snippet": "gamma"}, + ], + "citations": "on", + } + text = json.dumps(original, indent=2) + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + # Documents are path identifiers + assert result.documents == ["a.md", "b.md", "c.md"] + # Reorder: reverse the content strings + reordered = list(reversed(result.documents)) + rebuilt = reconstruct_content(result, reordered) + rebuilt_obj = json.loads(rebuilt) + # Top-level keys preserved + assert rebuilt_obj["citations"] == "on" + # Full result objects reordered by content mapping + assert rebuilt_obj["results"][0]["path"] == "c.md" + assert rebuilt_obj["results"][1]["path"] == "b.md" + assert rebuilt_obj["results"][2]["path"] == "a.md" + + +# ============================================================================ +# Auto detection priority +# ============================================================================ + + +class TestAutoDetection: + def test_xml_takes_priority(self): + text = "AB" + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "xml_tag" + + def test_numbered_before_separator(self): + text = "[1] Doc A [2] Doc B" + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "numbered" + + def test_json_results_before_separator(self): + import json + text = json.dumps({"results": [{"a": 1}, {"b": 2}]}, indent=2) + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "json_results" + + def test_separator_excluded_from_auto(self): + """separator is excluded from auto — matches YAML frontmatter too aggressively.""" + text = "First doc\n---\nSecond doc\n---\nThird doc" + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is None + + def test_separator_explicit_mode(self): + text = "First doc\n---\nSecond doc\n---\nThird doc" + config = InterceptConfig(mode="separator") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "separator" + + def test_nothing_returns_none(self): + text = "Just a plain system message with no documents." + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is None + + +# ============================================================================ +# Reconstruction roundtrips +# ============================================================================ + + +class TestReconstruction: + def test_xml_roundtrip(self): + text = "Prefix\n\nA\nB\nC\n\nSuffix" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + rebuilt = reconstruct_content(result, ["C", "A", "B"]) + assert "C" in rebuilt + assert "A" in rebuilt + assert "B" in rebuilt + assert rebuilt.startswith("Prefix\n") + assert rebuilt.endswith("\nSuffix") + assert "" in rebuilt + assert "" in rebuilt + + def test_numbered_roundtrip(self): + text = "[1] Alpha [2] Beta [3] Gamma" + config = InterceptConfig(mode="numbered") + result = extract_documents(text, config) + rebuilt = reconstruct_content(result, ["Gamma", "Alpha", "Beta"]) + assert "[1] Gamma" in rebuilt + assert "[2] Alpha" in rebuilt + assert "[3] Beta" in rebuilt + + def test_separator_roundtrip(self): + text = "Doc A\n---\nDoc B\n---\nDoc C" + config = InterceptConfig(mode="separator") + result = extract_documents(text, config) + rebuilt = reconstruct_content(result, ["Doc C", "Doc A", "Doc B"]) + parts = rebuilt.split("\n---\n") + assert parts == ["Doc C", "Doc A", "Doc B"] + + +# ============================================================================ +# OpenAI chat format +# ============================================================================ + + +class TestOpenAIChatFormat: + def test_extract_from_system_message(self): + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "AB"}, + {"role": "user", "content": "What is A?"}, + ], + } + config = InterceptConfig() + result = extract_from_openai_chat(body, config) + assert result is not None + extraction, idx = result + assert extraction.documents == ["A", "B"] + assert idx == 0 + + def test_no_system_message(self): + body = { + "messages": [ + {"role": "user", "content": "Hello"}, + ], + } + result = extract_from_openai_chat(body, InterceptConfig()) + assert result is None + + def test_system_without_docs(self): + body = { + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello"}, + ], + } + result = extract_from_openai_chat(body, InterceptConfig()) + assert result is None + + def test_reconstruct_roundtrip(self): + body = { + "model": "gpt-4", + "messages": [ + {"role": "system", "content": "\nA\nB\nC\n"}, + {"role": "user", "content": "Summarize"}, + ], + } + config = InterceptConfig() + extraction, idx = extract_from_openai_chat(body, config) + new_body = reconstruct_openai_chat(body, extraction, ["C", "A", "B"], idx) + # Original body not modified + assert "A" in body["messages"][0]["content"] + # New body has reordered docs + content = new_body["messages"][0]["content"] + assert "C" in content + # User message preserved + assert new_body["messages"][1]["content"] == "Summarize" + # Model preserved + assert new_body["model"] == "gpt-4" + + def test_content_blocks_format(self): + body = { + "model": "gpt-4", + "messages": [ + { + "role": "system", + "content": [ + {"type": "text", "text": "XY"} + ], + }, + {"role": "user", "content": "Question"}, + ], + } + config = InterceptConfig() + result = extract_from_openai_chat(body, config) + assert result is not None + extraction, idx = result + assert extraction.documents == ["X", "Y"] + + def test_empty_messages(self): + result = extract_from_openai_chat({"messages": []}, InterceptConfig()) + assert result is None + + def test_no_messages_key(self): + result = extract_from_openai_chat({"prompt": "hello"}, InterceptConfig()) + assert result is None + + +# ============================================================================ +# Anthropic messages format +# ============================================================================ + + +class TestAnthropicMessagesFormat: + def test_extract_string_system(self): + body = { + "model": "claude-3-opus-20240229", + "system": "AB", + "messages": [{"role": "user", "content": "Hello"}], + } + config = InterceptConfig() + result = extract_from_anthropic_messages(body, config) + assert result is not None + assert result.documents == ["A", "B"] + + def test_extract_content_blocks_system(self): + body = { + "model": "claude-3-opus-20240229", + "system": [ + {"type": "text", "text": "XY"} + ], + "messages": [{"role": "user", "content": "Hello"}], + } + config = InterceptConfig() + result = extract_from_anthropic_messages(body, config) + assert result is not None + assert result.documents == ["X", "Y"] + + def test_no_system_field(self): + body = {"messages": [{"role": "user", "content": "Hello"}]} + result = extract_from_anthropic_messages(body, InterceptConfig()) + assert result is None + + def test_reconstruct_string_system(self): + body = { + "system": "\nA\nB\n", + "messages": [{"role": "user", "content": "Hello"}], + } + config = InterceptConfig() + extraction = extract_from_anthropic_messages(body, config) + new_body = reconstruct_anthropic_messages(body, extraction, ["B", "A"]) + assert "B" in new_body["system"] + assert "A" in new_body["system"] + # Original not modified + assert body["system"].index("A") < body["system"].index("B") + + def test_reconstruct_content_blocks_system(self): + body = { + "system": [ + {"type": "text", "text": "PQ"} + ], + "messages": [{"role": "user", "content": "Hello"}], + } + config = InterceptConfig() + extraction = extract_from_anthropic_messages(body, config) + new_body = reconstruct_anthropic_messages(body, extraction, ["Q", "P"]) + text_block = new_body["system"][0]["text"] + assert "Q" in text_block + assert "P" in text_block + + +# ============================================================================ +# Markdown header extraction +# ============================================================================ + + +class TestMarkdownHeaderExtraction: + def test_basic_split(self): + text = "# Section A\nContent A\n\n# Section B\nContent B" + config = InterceptConfig(mode="markdown_header") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "markdown_header" + assert len(result.documents) == 2 + assert "# Section A" in result.documents[0] + assert "# Section B" in result.documents[1] + + def test_h2_headers(self): + text = "## Part 1\nText 1\n\n## Part 2\nText 2\n\n## Part 3\nText 3" + config = InterceptConfig(mode="markdown_header") + result = extract_documents(text, config) + assert result is not None + assert len(result.documents) == 3 + + def test_prefix_preserved(self): + text = "Some preamble\n\n# First\nA\n\n# Second\nB" + config = InterceptConfig(mode="markdown_header") + result = extract_documents(text, config) + assert result is not None + assert "preamble" in result.prefix + assert len(result.documents) == 2 + + def test_single_header_returns_none(self): + text = "# Only One Section\nSome content here" + config = InterceptConfig(mode="markdown_header") + result = extract_documents(text, config) + assert result is None + + def test_auto_priority_xml_over_markdown(self): + text = "AB" + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "xml_tag" + + def test_auto_does_not_use_markdown(self): + """markdown_header is excluded from auto to avoid splitting structural prompts.""" + text = "# Topic A\nContent about A\n\n# Topic B\nContent about B" + config = InterceptConfig(mode="auto") + result = extract_documents(text, config) + assert result is None + + def test_explicit_mode_still_works(self): + text = "# Topic A\nContent about A\n\n# Topic B\nContent about B" + config = InterceptConfig(mode="markdown_header") + result = extract_documents(text, config) + assert result is not None + assert result.mode == "markdown_header" + + def test_roundtrip(self): + text = "Preamble\n\n# Alpha\nContent A\n\n# Beta\nContent B" + config = InterceptConfig(mode="markdown_header") + result = extract_documents(text, config) + assert result is not None + rebuilt = reconstruct_content(result, list(reversed(result.documents))) + assert "# Beta" in rebuilt + assert "# Alpha" in rebuilt + # Prefix preserved + assert "Preamble" in rebuilt + + +# ============================================================================ +# File XML tags +# ============================================================================ + + +class TestFileXmlTags: + def test_files_wrapper(self): + text = "file1.py contentfile2.py content" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.mode == "xml_tag" + assert result.wrapper_tag == "files" + assert result.item_tag == "file" + assert result.documents == ["file1.py content", "file2.py content"] + + def test_file_tags_without_wrapper(self): + text = "A\nB" + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.item_tag == "file" + + +# ============================================================================ +# OpenAI tool result extraction +# ============================================================================ + + +class TestOpenAIToolResultExtraction: + def test_single_tool_result(self): + body = { + "messages": [ + {"role": "system", "content": "You are a helper."}, + {"role": "user", "content": "Search for X"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, + {"role": "tool", "tool_call_id": "tc1", + "content": "Result AResult B"}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + assert len(results) == 1 + ext, loc = results[0] + assert ext.documents == ["Result A", "Result B"] + assert loc.msg_index == 3 + assert loc.block_index == -1 + + def test_multiple_tool_results(self): + body = { + "messages": [ + {"role": "system", "content": "You are a helper."}, + {"role": "tool", "tool_call_id": "tc1", + "content": "[1] Doc A [2] Doc B"}, + {"role": "tool", "tool_call_id": "tc2", + "content": "[1] Doc C [2] Doc D [3] Doc E"}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + assert len(results) == 2 + assert results[0][0].documents == ["Doc A", "Doc B"] + assert results[1][0].documents == ["Doc C", "Doc D", "Doc E"] + + def test_skip_no_docs(self): + body = { + "messages": [ + {"role": "tool", "tool_call_id": "tc1", + "content": "No documents here, just a plain response."}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + assert len(results) == 0 + + def test_skip_single_doc(self): + body = { + "messages": [ + {"role": "tool", "tool_call_id": "tc1", + "content": "Only one"}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + assert len(results) == 0 + + def test_content_blocks_format(self): + body = { + "messages": [ + {"role": "tool", "tool_call_id": "tc1", + "content": [ + {"type": "text", "text": "XY"} + ]}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + assert len(results) == 1 + assert results[0][1].block_index == 0 + + def test_camelcase_toolresult_role(self): + """OpenClaw internal format uses role='toolResult' (camelCase).""" + body = { + "messages": [ + {"role": "toolResult", "toolCallId": "tc1", + "content": [ + {"type": "text", "text": "AB"} + ]}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + assert len(results) == 1 + assert results[0][0].documents == ["A", "B"] + + +# ============================================================================ +# Anthropic tool result extraction +# ============================================================================ + + +class TestAnthropicToolResultExtraction: + def test_string_content(self): + body = { + "messages": [ + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu1", + "content": "Res ARes B"}, + ]}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + assert len(results) == 1 + ext, loc = results[0] + assert ext.documents == ["Res A", "Res B"] + assert loc.msg_index == 0 + assert loc.block_index == 0 + assert loc.inner_block_index == -1 + + def test_content_blocks(self): + body = { + "messages": [ + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu1", + "content": [ + {"type": "text", "text": "[1] Alpha [2] Beta [3] Gamma"} + ]}, + ]}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + assert len(results) == 1 + ext, loc = results[0] + assert ext.documents == ["Alpha", "Beta", "Gamma"] + assert loc.inner_block_index == 0 + + def test_no_tool_result(self): + body = { + "messages": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there"}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + assert len(results) == 0 + + def test_multiple_tool_results_in_one_message(self): + body = { + "messages": [ + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu1", + "content": "AB"}, + {"type": "tool_result", "tool_use_id": "tu2", + "content": "CD"}, + ]}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + assert len(results) == 2 + + def test_camelcase_toolresult_type(self): + """OpenClaw internal format uses type='toolResult' (camelCase).""" + body = { + "messages": [ + {"role": "user", "content": [ + {"type": "toolResult", "toolUseId": "tu1", + "content": "XY"}, + ]}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + assert len(results) == 1 + assert results[0][0].documents == ["X", "Y"] + + +# ============================================================================ +# Tool result reconstruction +# ============================================================================ + + +class TestToolResultReconstruction: + def test_openai_string_roundtrip(self): + body = { + "messages": [ + {"role": "tool", "tool_call_id": "tc1", + "content": "AB"}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + ext, loc = results[0] + body_copy = copy.deepcopy(body) + reconstruct_openai_tool_result(body_copy, ext, ["B", "A"], loc) + content = body_copy["messages"][0]["content"] + assert "B" in content + assert content.index("B") < content.index("A") + + def test_openai_blocks_roundtrip(self): + body = { + "messages": [ + {"role": "tool", "tool_call_id": "tc1", + "content": [ + {"type": "text", "text": "[1] X [2] Y [3] Z"} + ]}, + ], + } + config = InterceptConfig() + results = extract_from_openai_tool_results(body, config) + ext, loc = results[0] + body_copy = copy.deepcopy(body) + reconstruct_openai_tool_result(body_copy, ext, ["Z", "X", "Y"], loc) + text = body_copy["messages"][0]["content"][0]["text"] + assert "[1] Z" in text + + def test_anthropic_string_roundtrip(self): + body = { + "messages": [ + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu1", + "content": "PQ"}, + ]}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + ext, loc = results[0] + body_copy = copy.deepcopy(body) + reconstruct_anthropic_tool_result(body_copy, ext, ["Q", "P"], loc) + content = body_copy["messages"][0]["content"][0]["content"] + assert "Q" in content + + def test_anthropic_nested_blocks_roundtrip(self): + body = { + "messages": [ + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu1", + "content": [ + {"type": "text", "text": "[1] First [2] Second"} + ]}, + ]}, + ], + } + config = InterceptConfig() + results = extract_from_anthropic_tool_results(body, config) + ext, loc = results[0] + body_copy = copy.deepcopy(body) + reconstruct_anthropic_tool_result(body_copy, ext, ["Second", "First"], loc) + text = body_copy["messages"][0]["content"][0]["content"][0]["text"] + assert "[1] Second" in text + + +# ============================================================================ +# Scope filtering +# ============================================================================ + + +class TestScopeFiltering: + def _make_openai_body(self): + return { + "messages": [ + {"role": "system", "content": "SysASysB"}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": None, "tool_calls": [{"id": "tc1"}]}, + {"role": "tool", "tool_call_id": "tc1", + "content": "ToolAToolB"}, + ], + } + + def _make_anthropic_body(self): + return { + "system": "SysXSysY", + "messages": [ + {"role": "user", "content": [ + {"type": "text", "text": "Hello"}, + ]}, + {"role": "assistant", "content": [ + {"type": "tool_use", "id": "tu1", "name": "search", "input": {}}, + ]}, + {"role": "user", "content": [ + {"type": "tool_result", "tool_use_id": "tu1", + "content": "ToolXToolY"}, + ]}, + ], + } + + def test_openai_scope_system_only(self): + body = self._make_openai_body() + config = InterceptConfig(scope="system") + result = extract_all_openai(body, config) + assert result.system_extraction is not None + assert len(result.tool_extractions) == 0 + + def test_openai_scope_tool_results_only(self): + body = self._make_openai_body() + config = InterceptConfig(scope="tool_results") + result = extract_all_openai(body, config) + assert result.system_extraction is None + assert len(result.tool_extractions) == 1 + + def test_openai_scope_all(self): + body = self._make_openai_body() + config = InterceptConfig(scope="all") + result = extract_all_openai(body, config) + assert result.system_extraction is not None + assert len(result.tool_extractions) == 1 + assert result.total_documents == 4 + + def test_anthropic_scope_system_only(self): + body = self._make_anthropic_body() + config = InterceptConfig(scope="system") + result = extract_all_anthropic(body, config) + assert result.system_extraction is not None + assert len(result.tool_extractions) == 0 + + def test_anthropic_scope_tool_results_only(self): + body = self._make_anthropic_body() + config = InterceptConfig(scope="tool_results") + result = extract_all_anthropic(body, config) + assert result.system_extraction is None + assert len(result.tool_extractions) == 1 + + def test_anthropic_scope_all(self): + body = self._make_anthropic_body() + config = InterceptConfig(scope="all") + result = extract_all_anthropic(body, config) + assert result.system_extraction is not None + assert len(result.tool_extractions) == 1 + assert result.total_documents == 4 + + +# ============================================================================ +# OpenClaw system prompt pattern +# ============================================================================ + + +class TestOpenClawSystemPromptPattern: + def test_openclaw_files_pattern(self): + """Simulate OpenClaw's system prompt with tags.""" + text = ( + "You are an AI assistant.\n" + "\n" + "path/to/file1.py\nclass Foo:\n pass\n" + "path/to/file2.py\ndef bar():\n return 42\n" + "path/to/file3.py\nimport os\nos.path.join('a', 'b')\n" + "\n" + "Answer the user's questions based on the files above." + ) + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.mode == "xml_tag" + assert result.wrapper_tag == "files" + assert result.item_tag == "file" + assert len(result.documents) == 3 + assert "class Foo" in result.documents[0] + assert "Answer the user" in result.suffix + + def test_openclaw_json_tool_result_pattern(self): + """Simulate OpenClaw's memory search tool result (JSON with results array).""" + import json + text = json.dumps({ + "results": [ + {"path": "intro.md", "snippet": "Python is a high-level language.", "score": 0.9}, + {"path": "types.md", "snippet": "Python has several data types.", "score": 0.8}, + {"path": "funcs.md", "snippet": "Functions are reusable code blocks.", "score": 0.7}, + ], + "citations": "auto", + }, indent=2) + config = InterceptConfig() + result = extract_documents(text, config) + assert result is not None + assert result.mode == "json_results" + assert len(result.documents) == 3 diff --git a/tests/test_live_index.py b/tests/test_live_index.py index 027c230..17c13c7 100644 --- a/tests/test_live_index.py +++ b/tests/test_live_index.py @@ -11,29 +11,29 @@ class TestLiveIndexInitialization: """Test live index initialization.""" - + def test_live_index_creation(self): """Test basic live index creation.""" from contextpilot import ContextPilot - + index = ContextPilot( alpha=0.001, use_gpu=False, ) - + assert index is not None assert index.is_live is False - + def test_live_index_with_different_configs(self): """Test live index with various configurations.""" from contextpilot import ContextPilot - + configs = [ {"alpha": 0.001}, {"alpha": 0.01}, {"alpha": 0.001}, ] - + for config in configs: index = ContextPilot(use_gpu=False, **config) assert index is not None @@ -41,150 +41,99 @@ def test_live_index_with_different_configs(self): class TestBuildAndSchedule: """Test build and schedule functionality.""" - + def test_build_and_schedule(self): """Test building and scheduling contexts.""" from contextpilot import ContextPilot - + index = ContextPilot(use_gpu=False) - + contexts = [ [1, 2, 3, 4, 5], [1, 2, 3, 6, 7], [8, 9, 10, 11, 12], ] - + result = index.build_and_schedule(contexts) - + assert result is not None assert index.initial_result is not None assert index.scheduled_result is not None - + def test_index_becomes_live_after_build(self): """Test that index becomes live after build_and_schedule.""" from contextpilot import ContextPilot - + index = ContextPilot(use_gpu=False) - + contexts = [[1, 2, 3], [4, 5, 6]] index.build_and_schedule(contexts) - + # build_and_schedule automatically sets is_live = True assert index.is_live is True - + def test_schedule_only_stateless(self): """Test schedule_only for stateless mode.""" from contextpilot import ContextPilot - + index = ContextPilot(use_gpu=False) - + contexts = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] result = index.schedule_only(contexts) - + assert result is not None - assert 'reordered_contexts' in result - assert 'scheduled_originals' in result - assert 'original_indices' in result + assert "reordered_contexts" in result + assert "scheduled_originals" in result + assert "original_indices" in result # In stateless mode, is_live should remain False assert index.is_live is False -class TestEvictionHeap: - """Test eviction heap functionality.""" - - def test_eviction_heap_initialization(self): - """Test eviction heap initializes correctly.""" - from contextpilot.server.eviction_heap import EvictionHeap - - heap = EvictionHeap(max_tokens=10000) - - assert heap is not None - assert heap.max_tokens == 10000 - - def test_eviction_heap_push(self): - """Test pushing metadata to eviction heap.""" - from contextpilot.server.eviction_heap import EvictionHeap - from contextpilot.server.metadata import NodeMetadata - - heap = EvictionHeap(max_tokens=10000) - - metadata = NodeMetadata(node_id=1, total_tokens=100, extra_tokens=50) - heap.push(metadata) - - assert len(heap) == 1 - - def test_eviction_heap_pop(self): - """Test popping from eviction heap.""" - from contextpilot.server.eviction_heap import EvictionHeap - from contextpilot.server.metadata import NodeMetadata - import time - - heap = EvictionHeap(max_tokens=10000) - - # Add items with different timestamps - m1 = NodeMetadata(node_id=1, total_tokens=100, extra_tokens=50) - m1.last_access_time = time.time() - 100 # Oldest - - m2 = NodeMetadata(node_id=2, total_tokens=100, extra_tokens=50) - m2.last_access_time = time.time() # Newest - - heap.push(m1) - heap.push(m2) - - # Pop should return oldest (LRU) - popped = heap.pop() - assert popped.node_id == 1 - - class TestNodeMetadata: """Test node metadata handling.""" - + def test_metadata_creation(self): """Test creating node metadata.""" from contextpilot.server.metadata import NodeMetadata - - metadata = NodeMetadata( - node_id=1, - total_tokens=100, - extra_tokens=50 - ) - + + metadata = NodeMetadata(node_id=1, total_tokens=100, extra_tokens=50) + assert metadata.node_id == 1 assert metadata.total_tokens == 100 assert metadata.extra_tokens == 50 - + def test_metadata_access_time_update(self): """Test updating access time.""" from contextpilot.server.metadata import NodeMetadata import time - + metadata = NodeMetadata(node_id=1) old_time = metadata.last_access_time - + time.sleep(0.01) metadata.update_access_time() - + assert metadata.last_access_time > old_time - + def test_metadata_add_tokens(self): """Test adding tokens to metadata.""" from contextpilot.server.metadata import NodeMetadata - + metadata = NodeMetadata(node_id=1, total_tokens=100, extra_tokens=50) - + metadata.add_tokens(25) - + assert metadata.total_tokens == 125 assert metadata.extra_tokens == 75 - + def test_metadata_remove_tokens(self): """Test removing tokens from metadata.""" from contextpilot.server.metadata import NodeMetadata - + metadata = NodeMetadata(node_id=1, total_tokens=100, extra_tokens=50) - + removed = metadata.remove_tokens(30) - + assert removed == 30 assert metadata.extra_tokens == 20 assert metadata.total_tokens == 70 @@ -192,68 +141,68 @@ def test_metadata_remove_tokens(self): class TestComputePrefixLength: """Test prefix length computation utility.""" - + def test_identical_lists(self): """Identical lists should have full prefix length.""" from contextpilot.server.live_index import compute_prefix_length - + list1 = [1, 2, 3, 4, 5] list2 = [1, 2, 3, 4, 5] - + assert compute_prefix_length(list1, list2) == 5 - + def test_partial_prefix(self): """Lists with partial prefix should return correct length.""" from contextpilot.server.live_index import compute_prefix_length - + list1 = [1, 2, 3, 100, 200] list2 = [1, 2, 3, 300, 400] - + assert compute_prefix_length(list1, list2) == 3 - + def test_no_common_prefix(self): """Lists with no common prefix should return 0.""" from contextpilot.server.live_index import compute_prefix_length - + list1 = [1, 2, 3] list2 = [4, 5, 6] - + assert compute_prefix_length(list1, list2) == 0 - + def test_empty_list(self): """Empty list should have 0 prefix length.""" from contextpilot.server.live_index import compute_prefix_length - + assert compute_prefix_length([], [1, 2, 3]) == 0 assert compute_prefix_length([1, 2, 3], []) == 0 assert compute_prefix_length([], []) == 0 - + def test_different_length_lists(self): """Lists of different lengths should work correctly.""" from contextpilot.server.live_index import compute_prefix_length - + list1 = [1, 2, 3] list2 = [1, 2, 3, 4, 5] - + assert compute_prefix_length(list1, list2) == 3 class TestLiveIndexRequestTracking: """Test request tracking in live index.""" - + def test_request_id_auto_generated(self): """Test that request IDs are auto-generated during build.""" from contextpilot import ContextPilot - + index = ContextPilot(use_gpu=False) - + contexts = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] result = index.build_and_schedule(contexts) - + # Should have request_id_mapping in result - assert 'request_id_mapping' in result - assert 'request_ids' in result - assert len(result['request_ids']) == len(contexts) + assert "request_id_mapping" in result + assert "request_ids" in result + assert len(result["request_ids"]) == len(contexts) def test_reorder_single_list(self): """reorder() should accept a single list and auto-wrap it.""" diff --git a/tests/test_ttl_eviction.py b/tests/test_ttl_eviction.py new file mode 100644 index 0000000..22edd71 --- /dev/null +++ b/tests/test_ttl_eviction.py @@ -0,0 +1,297 @@ +"""Tests for TTL-based eviction policy.""" + +import time +import threading +from unittest.mock import patch + +import pytest + +from contextpilot.server.ttl_eviction import ( + TTLEvictionPolicy, + TTLTier, + CacheEntry, + CacheMetrics, +) + + +class TestTTLTier: + def test_short_tier_seconds(self): + assert TTLTier.SHORT.seconds == 300 + + def test_medium_tier_seconds(self): + assert TTLTier.MEDIUM.seconds == 3600 + + def test_long_tier_seconds(self): + assert TTLTier.LONG.seconds == 86400 + + def test_tier_values(self): + assert TTLTier.SHORT.value == "5m" + assert TTLTier.MEDIUM.value == "1h" + assert TTLTier.LONG.value == "24h" + + def test_tier_from_string(self): + assert TTLTier("5m") is TTLTier.SHORT + assert TTLTier("1h") is TTLTier.MEDIUM + assert TTLTier("24h") is TTLTier.LONG + + +class TestCacheEntry: + def test_not_expired_when_fresh(self): + now = time.time() + entry = CacheEntry( + content_hash="abc", + request_id="req-abc", + created_at=now, + last_accessed_at=now, + ttl_seconds=300, + ) + assert not entry.is_expired(now) + + def test_expired_after_ttl(self): + now = time.time() + entry = CacheEntry( + content_hash="abc", + request_id="req-abc", + created_at=now - 400, + last_accessed_at=now - 400, + ttl_seconds=300, + ) + assert entry.is_expired(now) + + def test_time_remaining_positive(self): + now = time.time() + entry = CacheEntry( + content_hash="abc", + request_id="req-abc", + created_at=now, + last_accessed_at=now, + ttl_seconds=300, + ) + assert entry.time_remaining(now) == pytest.approx(300, abs=1) + + def test_time_remaining_negative_when_expired(self): + now = time.time() + entry = CacheEntry( + content_hash="abc", + request_id="req-abc", + created_at=now - 400, + last_accessed_at=now - 400, + ttl_seconds=300, + ) + assert entry.time_remaining(now) < 0 + + +class TestCacheMetrics: + def test_defaults(self): + m = CacheMetrics() + assert m.cache_creation_tokens == 0 + assert m.cache_read_tokens == 0 + assert m.input_tokens == 0 + assert m.output_tokens == 0 + + def test_custom_values(self): + m = CacheMetrics( + cache_creation_tokens=1000, + cache_read_tokens=500, + input_tokens=1500, + output_tokens=200, + ) + assert m.cache_creation_tokens == 1000 + assert m.cache_read_tokens == 500 + + +class TestTTLEvictionPolicy: + def test_add_entry(self): + policy = TTLEvictionPolicy() + entry = policy.add_entry("hash1", content_hash="hash1", token_count=5000) + assert entry is not None + assert entry.content_hash == "hash1" + assert entry.token_count == 5000 + assert policy.get_cached_count() == 1 + + def test_add_entry_with_custom_ttl_seconds(self): + policy = TTLEvictionPolicy(default_ttl=TTLTier.LONG, default_ttl_seconds=86400) + entry = policy.add_entry("hash1") + assert entry is not None + assert entry.ttl_seconds == 86400 + + def test_add_existing_refreshes(self): + policy = TTLEvictionPolicy() + e1 = policy.add_entry("hash1", token_count=100) + assert e1 is not None + t1 = e1.last_accessed_at + time.sleep(0.01) + e2 = policy.add_entry("hash1", token_count=200) + assert e2 is not None + assert e2.last_accessed_at > t1 + assert e2.token_count == 200 + assert policy.get_cached_count() == 1 + + def test_is_cached_true(self): + policy = TTLEvictionPolicy() + policy.add_entry("hash1") + assert policy.is_cached("hash1") + + def test_is_cached_false_not_found(self): + policy = TTLEvictionPolicy() + assert not policy.is_cached("nonexistent") + + def test_is_cached_false_after_expiry(self): + policy = TTLEvictionPolicy() + policy.add_entry("hash1") + now = time.time() + 600 + with patch("contextpilot.server.ttl_eviction.time.time", return_value=now): + assert not policy.is_cached("hash1") + + def test_touch_entry(self): + policy = TTLEvictionPolicy() + policy.add_entry("hash1") + assert policy.touch_entry("hash1") + assert not policy.touch_entry("nonexistent") + + def test_touch_expired_returns_false(self): + policy = TTLEvictionPolicy() + policy.add_entry("hash1") + now = time.time() + 600 + with patch("contextpilot.server.ttl_eviction.time.time", return_value=now): + assert not policy.touch_entry("hash1") + + def test_evict_expired(self): + policy = TTLEvictionPolicy() + policy.add_entry("hash1", token_count=100) + policy.add_entry("hash2", token_count=200) + + now = time.time() + evicted = policy.evict_expired() + assert len(evicted) == 0 + + future = now + 600 + with patch("contextpilot.server.ttl_eviction.time.time", return_value=future): + evicted = policy.evict_expired() + assert len(evicted) == 2 + + def test_evict_only_expired_short_vs_long(self): + short_policy = TTLEvictionPolicy( + default_ttl=TTLTier.SHORT, default_ttl_seconds=300 + ) + short_policy.add_entry("short_hash", content_hash="short_hash") + + long_policy = TTLEvictionPolicy( + default_ttl=TTLTier.LONG, default_ttl_seconds=3600 + ) + long_policy.add_entry("long_hash", content_hash="long_hash") + + future_6min = time.time() + 360 + with patch( + "contextpilot.server.ttl_eviction.time.time", return_value=future_6min + ): + evicted_short = short_policy.evict_expired() + evicted_long = long_policy.evict_expired() + assert len(evicted_short) == 1 + assert evicted_short[0].content_hash == "short_hash" + assert len(evicted_long) == 0 + + def test_get_cached_hashes(self): + policy = TTLEvictionPolicy() + policy.add_entry("a", content_hash="a") + policy.add_entry("b", content_hash="b") + policy.add_entry("c", content_hash="c") + hashes = policy.get_cached_hashes() + assert hashes == {"a", "b", "c"} + + def test_get_total_cached_tokens(self): + policy = TTLEvictionPolicy() + policy.add_entry("a", token_count=100) + policy.add_entry("b", token_count=200) + assert policy.get_total_cached_tokens() == 300 + + def test_update_from_response_cache_hit(self): + policy = TTLEvictionPolicy() + policy.add_entry("hash1", token_count=100) + metrics = CacheMetrics(cache_read_tokens=100) + policy.update_from_response(metrics, "hash1") + stats = policy.get_stats() + assert stats["total_hits"] >= 1 + + def test_update_from_response_cache_write(self): + policy = TTLEvictionPolicy() + metrics = CacheMetrics(cache_creation_tokens=5000) + policy.update_from_response(metrics, "new_hash") + assert policy.is_cached("new_hash") + + def test_reset(self): + policy = TTLEvictionPolicy() + policy.add_entry("a") + policy.add_entry("b") + policy.reset() + assert policy.get_cached_count() == 0 + stats = policy.get_stats() + assert stats["total_hits"] == 0 + assert stats["total_misses"] == 0 + + def test_get_stats(self): + policy = TTLEvictionPolicy() + policy.add_entry("a", token_count=100) + policy.is_cached("a") + policy.is_cached("missing") + + stats = policy.get_stats() + assert stats["active_entries"] == 1 + assert stats["total_cached_tokens"] == 100 + assert stats["total_hits"] >= 1 + assert stats["total_misses"] >= 1 + assert stats["default_ttl"] == "5m" + assert stats["default_ttl_seconds"] == 300 + + def test_default_ttl_property(self): + policy = TTLEvictionPolicy(default_ttl=TTLTier.LONG) + assert policy.default_ttl == TTLTier.LONG + policy.default_ttl = TTLTier.SHORT + assert policy.default_ttl == TTLTier.SHORT + + def test_len(self): + policy = TTLEvictionPolicy() + assert len(policy) == 0 + policy.add_entry("a") + policy.add_entry("b") + assert len(policy) == 2 + + def test_repr(self): + policy = TTLEvictionPolicy() + policy.add_entry("a") + r = repr(policy) + assert "active=1" in r + assert "default_ttl=5m" in r + + def test_thread_safety(self): + policy = TTLEvictionPolicy() + errors = [] + + def writer(prefix, count): + try: + for i in range(count): + policy.add_entry(f"{prefix}_{i}", token_count=i) + except Exception as e: + errors.append(e) + + def reader(count): + try: + for _ in range(count): + policy.get_cached_hashes() + policy.get_stats() + policy.evict_expired() + except Exception as e: + errors.append(e) + + threads = [ + threading.Thread(target=writer, args=("a", 100)), + threading.Thread(target=writer, args=("b", 100)), + threading.Thread(target=reader, args=(50,)), + ] + for t in threads: + t.start() + for t in threads: + t.join() + + assert not errors, f"Thread safety errors: {errors}" + assert policy.get_cached_count() == 200