From 2ae71092fde77274d62a508052f6e55029920aaa Mon Sep 17 00:00:00 2001 From: Etan Joseph Heyman Date: Fri, 19 Jun 2026 19:17:40 +0300 Subject: [PATCH] fix: fall back to FTS when search embeddings stall --- hooks/brainlayer-prompt-search.py | 37 ++++++++++- src/brainlayer/mcp/search_handler.py | 69 +++++++++++++++++-- src/brainlayer/search_repo.py | 15 ++++- tests/test_adaptive_injection.py | 43 ++++++++++++ tests/test_hybrid_search.py | 17 +++++ tests/test_search_handler.py | 99 +++++++++++++++++++++++++++- 6 files changed, 270 insertions(+), 10 deletions(-) diff --git a/hooks/brainlayer-prompt-search.py b/hooks/brainlayer-prompt-search.py index b05f198b..e63120fa 100755 --- a/hooks/brainlayer-prompt-search.py +++ b/hooks/brainlayer-prompt-search.py @@ -13,9 +13,11 @@ import json import os +import queue import re import sqlite3 import sys +import threading import time from datetime import datetime, timedelta, timezone from hashlib import sha256 @@ -34,6 +36,7 @@ MAX_ADAPTIVE_INJECTION = 3 MAX_HYBRID_CANDIDATES = 8 DEGRADED_PREFIX = "⚠️ DEGRADED: BrainLayer" +DEFAULT_EMBED_TIMEOUT_MS = 1000.0 def degraded_notice(reason): @@ -44,6 +47,38 @@ def emit_degraded(reason): print(degraded_notice(reason)) +def embed_timeout_ms(): + raw = os.environ.get("BRAINLAYER_EMBED_TIMEOUT_MS", str(DEFAULT_EMBED_TIMEOUT_MS)) + try: + value = float(raw) + except (TypeError, ValueError): + return DEFAULT_EMBED_TIMEOUT_MS + if not value or value < 0: + return DEFAULT_EMBED_TIMEOUT_MS + return min(value, 30_000.0) + + +def run_with_timeout(func, timeout_ms, *args, **kwargs): + results = queue.Queue(maxsize=1) + + def target(): + try: + results.put((True, func(*args, **kwargs))) + except BaseException as exc: + results.put((False, exc)) + + thread = threading.Thread(target=target, name="brainlayer-hook-hybrid", daemon=True) + thread.start() + thread.join(timeout_ms / 1000.0) + if thread.is_alive(): + raise TimeoutError(f"hybrid search exceeded {timeout_ms:.0f}ms") + + ok, payload = results.get_nowait() + if ok: + return payload + raise payload + + # Prompts shorter than this are probably greetings/commands — skip search MIN_PROMPT_LENGTH = 15 HEBREW_CANDIDATE_RE = re.compile(r"[\u0590-\u05FF]{2,}") @@ -944,7 +979,7 @@ def run_fts_search(db_path, keywords, limit): def search_prompt_chunks(prompt, db_path, keywords, limit): """Search with hybrid first, then fall back to FTS-only behavior.""" try: - return run_hybrid_search(prompt, db_path, keywords, limit), True + return run_with_timeout(run_hybrid_search, embed_timeout_ms(), prompt, db_path, keywords, limit), True except Exception: return run_fts_search(db_path, keywords, limit), False diff --git a/src/brainlayer/mcp/search_handler.py b/src/brainlayer/mcp/search_handler.py index e570d922..83b09eeb 100644 --- a/src/brainlayer/mcp/search_handler.py +++ b/src/brainlayer/mcp/search_handler.py @@ -34,6 +34,7 @@ _ORIGIN_ORDER_LABEL = "- Order: origin (earliest among expanded hybrid candidates)" _HELPER_SOCKET_GLOB = "/tmp/brainbar-hybrid-*.sock" _HELPER_SOCKET_TIMEOUT_SECONDS = 2.0 +_DEFAULT_EMBED_TIMEOUT_MS = 1000.0 from ._format import format_kg_search, format_recalled_context, format_search_results, format_stats from ._shared import ( @@ -60,6 +61,17 @@ def _get_vector_store(): return _get_search_vector_store() +def _embed_timeout_ms() -> float: + raw = os.environ.get("BRAINLAYER_EMBED_TIMEOUT_MS", str(_DEFAULT_EMBED_TIMEOUT_MS)) + try: + value = float(raw) + except (TypeError, ValueError): + return _DEFAULT_EMBED_TIMEOUT_MS + if not value or value < 0: + return _DEFAULT_EMBED_TIMEOUT_MS + return min(value, 30_000.0) + + def _origin_candidate_count(num_results: int) -> int: return min(_MAX_PUBLIC_NUM_RESULTS, max(num_results, _ORIGIN_CANDIDATE_LIMIT)) @@ -1767,20 +1779,44 @@ async def _search( normalized_project = _normalize_project_name(project) loop = asyncio.get_running_loop() - model = _get_embedding_model() embed_started = search_profile.now() + query_embedding = None + search_mode = "hybrid" + fallback_reason = None try: - query_embedding = await loop.run_in_executor(None, model.embed_query, query) + embed_timeout_ms = _embed_timeout_ms() + query_embedding = await asyncio.wait_for( + loop.run_in_executor( + None, + lambda: _get_embedding_model().embed_query(query), + ), + timeout=embed_timeout_ms / 1000.0, + ) + except TimeoutError as exc: + search_mode = "fts_only" + fallback_reason = "embed_timeout" + search_profile.emit( + profile_scope, + "embed", + profile_query_id, + search_profile.dur_ms(embed_started), + error=exc.__class__.__name__, + timeout_ms=embed_timeout_ms, + fallback="fts_only", + ) except Exception as exc: + search_mode = "fts_only" + fallback_reason = f"embed_error:{exc.__class__.__name__}" search_profile.emit( profile_scope, "embed", profile_query_id, search_profile.dur_ms(embed_started), error=exc.__class__.__name__, + fallback="fts_only", ) - raise - search_profile.emit(profile_scope, "embed", profile_query_id, search_profile.dur_ms(embed_started)) + else: + search_profile.emit(profile_scope, "embed", profile_query_id, search_profile.dur_ms(embed_started)) if source == "all": source_filter = None @@ -1890,7 +1926,14 @@ async def _search( } ) structured_results.append(item) - structured = {"query": query, "total": len(structured_results), "results": structured_results} + structured = { + "query": query, + "total": len(structured_results), + "results": structured_results, + "search_mode": search_mode, + } + if fallback_reason: + structured["fallback_reason"] = fallback_reason if order == "origin": structured["order"] = order structured["order_scope"] = _ORIGIN_ORDER_SCOPE @@ -1900,9 +1943,16 @@ async def _search( len(structured_results), order=order if order == "origin" else None, ) + if search_mode == "fts_only": + formatted_text = ( + f"{formatted_text}\n\n" + f"Search mode: FTS-only fallback ({fallback_reason}); vector embedding was skipped." + ) return ([TextContent(type="text", text=formatted_text)], structured) output_parts = [f"## Search Results for: {query}\n"] + if search_mode == "fts_only": + output_parts.append(f"Search mode: FTS-only fallback ({fallback_reason}); vector embedding was skipped.") if order == "origin": output_parts.append(_ORIGIN_ORDER_LABEL) structured_results = [] @@ -1973,7 +2023,14 @@ async def _search( output_parts.append(doc) output_parts.append("\n---") - structured = {"query": query, "total": len(structured_results), "results": structured_results} + structured = { + "query": query, + "total": len(structured_results), + "results": structured_results, + "search_mode": search_mode, + } + if fallback_reason: + structured["fallback_reason"] = fallback_reason if order == "origin": structured["order"] = order structured["order_scope"] = _ORIGIN_ORDER_SCOPE diff --git a/src/brainlayer/search_repo.py b/src/brainlayer/search_repo.py index a4611487..df80d8ce 100644 --- a/src/brainlayer/search_repo.py +++ b/src/brainlayer/search_repo.py @@ -1711,7 +1711,7 @@ def _rerank_binary_results_with_float( def hybrid_search( self, - query_embedding: List[float], + query_embedding: Optional[List[float]], query_text: str, fts_query_override: Optional[str] = None, n_results: int = 10, @@ -1810,8 +1810,19 @@ def hybrid_search( # 1. Semantic search leg — prefer binary vectors, fall back to float vectors # when the binary index is unavailable (for example readonly live DBs). + # A None embedding is an intentional FTS-only fallback: query embedding + # must never block lexical reads. candidate_fetch_count = max(n_results * 3, _MMR_CANDIDATE_LIMIT) - if getattr(self, "_binary_index_available", False): + if query_embedding is None: + semantic = {"ids": [[]], "documents": [[]], "metadatas": [[]], "distances": [[]]} + search_profile.emit( + profile_scope, + "semantic_skip", + profile_query_id, + 0.0, + reason="missing_query_embedding", + ) + elif getattr(self, "_binary_index_available", False): binary_started = search_profile.now() semantic = self._binary_search( query_embedding=query_embedding, diff --git a/tests/test_adaptive_injection.py b/tests/test_adaptive_injection.py index b5a70f96..1b83d7f5 100644 --- a/tests/test_adaptive_injection.py +++ b/tests/test_adaptive_injection.py @@ -1,6 +1,7 @@ """Tests for score-based adaptive prompt injection in the BrainLayer hook.""" import importlib.util +import time from pathlib import Path import pytest @@ -155,6 +156,48 @@ def test_fallback_to_fts_only(self, prompt_search, monkeypatch): assert used_hybrid is False assert [row["id"] for row in rows] == ["fts-best"] + def test_slow_hybrid_search_falls_back_to_fts_only_within_timeout(self, prompt_search, monkeypatch): + fts_rows = [_row("fts-timeout", 0.0)] + + def slow_hybrid(*args, **kwargs): + time.sleep(0.05) + return [_row("late-hybrid", 0.02)] + + monkeypatch.setenv("BRAINLAYER_EMBED_TIMEOUT_MS", "1") + monkeypatch.setattr(prompt_search, "run_hybrid_search", slow_hybrid) + monkeypatch.setattr(prompt_search, "run_fts_search", lambda *args, **kwargs: fts_rows) + + started = time.monotonic() + rows, used_hybrid = prompt_search.search_prompt_chunks( + prompt="keyword fallback query", + db_path="/tmp/test.db", + keywords=["keyword", "fallback"], + limit=8, + ) + elapsed = time.monotonic() - started + + assert elapsed < 0.5 + assert used_hybrid is False + assert [row["id"] for row in rows] == ["fts-timeout"] + + def test_fast_hybrid_search_stays_on_hybrid_path(self, prompt_search, monkeypatch): + hybrid_rows = [_row("hybrid-best", 0.02)] + fts_rows = [_row("fts-unused", 0.0)] + + monkeypatch.setenv("BRAINLAYER_EMBED_TIMEOUT_MS", "1000") + monkeypatch.setattr(prompt_search, "run_hybrid_search", lambda *args, **kwargs: hybrid_rows) + monkeypatch.setattr(prompt_search, "run_fts_search", lambda *args, **kwargs: fts_rows) + + rows, used_hybrid = prompt_search.search_prompt_chunks( + prompt="keyword fallback query", + db_path="/tmp/test.db", + keywords=["keyword", "fallback"], + limit=8, + ) + + assert used_hybrid is True + assert [row["id"] for row in rows] == ["hybrid-best"] + def test_hybrid_search_opens_vector_store_readonly(self, prompt_search, monkeypatch, tmp_path): opened = [] diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py index 8e308647..9d184ee5 100644 --- a/tests/test_hybrid_search.py +++ b/tests/test_hybrid_search.py @@ -217,6 +217,23 @@ def test_hybrid_search_fts_only_fallback(self, store): assert "fts-hit" in results["ids"][0] + def test_hybrid_search_accepts_none_embedding_for_fts_only_fallback(self, store): + _insert_chunk( + store, + chunk_id="fts-none-embedding-hit", + content="exact keyword fallback when embedding is unavailable", + embedding=_embed("distant vector"), + ) + store.build_binary_index() + + results = store.hybrid_search( + query_embedding=None, + query_text="embedding unavailable", + n_results=5, + ) + + assert results["ids"][0] == ["fts-none-embedding-hit"] + def test_hybrid_search_fts_only_returns_provenance_metadata(self, store): cursor = store.conn.cursor() columns = {row[1] for row in cursor.execute("PRAGMA table_info(chunks)")} diff --git a/tests/test_search_handler.py b/tests/test_search_handler.py index 1e8716f5..3fdcd393 100644 --- a/tests/test_search_handler.py +++ b/tests/test_search_handler.py @@ -1,7 +1,9 @@ +import time + import pytest from brainlayer.mcp import call_tool -from brainlayer.mcp.search_handler import _brain_search +from brainlayer.mcp.search_handler import _brain_search, _search from brainlayer.vector_store import VectorStore @@ -10,6 +12,17 @@ def embed_query(self, _query: str) -> list[float]: return [0.1, 0.2, 0.3] +class SlowEmbeddingModel: + def embed_query(self, _query: str) -> list[float]: + time.sleep(0.05) + return [0.9, 0.9, 0.9] + + +class RaisingEmbeddingModel: + def embed_query(self, _query: str) -> list[float]: + raise RuntimeError("embedding backend unavailable") + + class OriginEmbeddingModel: def __init__(self, embedding: list[float]) -> None: self.embedding = embedding @@ -38,6 +51,24 @@ def enrich_results_with_session_context(self, results): return results +class FtsFallbackSearchStore(RecordingSearchStore): + def hybrid_search(self, **kwargs): + self.hybrid_kwargs = kwargs + if kwargs["query_embedding"] is None: + return { + "ids": [["fts-fallback-hit"]], + "documents": [["keyword fallback result from FTS"]], + "metadatas": [[{"source_file": "fts.md", "project": "brainlayer"}]], + "distances": [[1.0]], + } + return { + "ids": [["hybrid-hit"]], + "documents": [["semantic and keyword hybrid result"]], + "metadatas": [[{"source_file": "hybrid.md", "project": "brainlayer"}]], + "distances": [[0.25]], + } + + class RecordingKgSearchStore(RecordingSearchStore): def __init__(self) -> None: super().__init__() @@ -136,6 +167,72 @@ def _seed_origin_search_store(db_path) -> tuple[VectorStore, list[float]]: return store, query_embedding +@pytest.mark.asyncio +async def test_brain_search_falls_back_to_fts_when_embedding_times_out(monkeypatch): + store = FtsFallbackSearchStore() + + monkeypatch.setenv("BRAINLAYER_EMBED_TIMEOUT_MS", "1") + monkeypatch.setattr("brainlayer.mcp.search_handler._get_vector_store", lambda: store) + monkeypatch.setattr("brainlayer.mcp.search_handler._get_embedding_model", lambda: SlowEmbeddingModel()) + monkeypatch.setattr("brainlayer.mcp.search_handler._normalize_project_name", lambda project: project) + + started = time.monotonic() + content, structured = await _search( + query="keyword fallback", + source="all", + num_results=1, + ) + elapsed = time.monotonic() - started + + assert elapsed < 0.5 + assert store.hybrid_kwargs["query_embedding"] is None + assert structured["search_mode"] == "fts_only" + assert structured["fallback_reason"] == "embed_timeout" + assert [item["chunk_id"] for item in structured["results"]] == ["fts-fallback-hit"] + assert "FTS-only fallback" in content[0].text + + +@pytest.mark.asyncio +async def test_brain_search_falls_back_to_fts_when_embedding_raises(monkeypatch): + store = FtsFallbackSearchStore() + + monkeypatch.setattr("brainlayer.mcp.search_handler._get_vector_store", lambda: store) + monkeypatch.setattr("brainlayer.mcp.search_handler._get_embedding_model", lambda: RaisingEmbeddingModel()) + monkeypatch.setattr("brainlayer.mcp.search_handler._normalize_project_name", lambda project: project) + + _content, structured = await _search( + query="keyword fallback", + source="all", + num_results=1, + ) + + assert store.hybrid_kwargs["query_embedding"] is None + assert structured["search_mode"] == "fts_only" + assert structured["fallback_reason"] == "embed_error:RuntimeError" + assert [item["chunk_id"] for item in structured["results"]] == ["fts-fallback-hit"] + + +@pytest.mark.asyncio +async def test_brain_search_uses_hybrid_when_embedding_is_fast(monkeypatch): + store = FtsFallbackSearchStore() + + monkeypatch.setenv("BRAINLAYER_EMBED_TIMEOUT_MS", "1000") + monkeypatch.setattr("brainlayer.mcp.search_handler._get_vector_store", lambda: store) + monkeypatch.setattr("brainlayer.mcp.search_handler._get_embedding_model", lambda: FakeEmbeddingModel()) + monkeypatch.setattr("brainlayer.mcp.search_handler._normalize_project_name", lambda project: project) + + _content, structured = await _search( + query="keyword fallback", + source="all", + num_results=1, + ) + + assert store.hybrid_kwargs["query_embedding"] == [0.1, 0.2, 0.3] + assert structured["search_mode"] == "hybrid" + assert "fallback_reason" not in structured + assert [item["chunk_id"] for item in structured["results"]] == ["hybrid-hit"] + + @pytest.mark.asyncio async def test_brain_search_mcp_threads_agent_id_to_hybrid_search(monkeypatch): store = RecordingSearchStore()