diff --git a/apps/api/core/substitution.py b/apps/api/core/substitution.py index 70539c9..07ec0bc 100644 --- a/apps/api/core/substitution.py +++ b/apps/api/core/substitution.py @@ -64,6 +64,17 @@ def _supports_idempotency_key(slug: str) -> bool: return slug in _IDEMPOTENCY_KEY_PROVIDERS + +# Provider-specific params that CANNOT carry to a different provider in the group. +# For llm-inference, an explicitly-pinned `model` is a vendor model name that a +# substitute cannot honor. Policy (decided): SURFACE a clean error — never strip +# the param and never silently fall back to the substitute's default model. (A +# curated model-equivalence map is a separate follow-up.) Unpinned calls self-heal. +def _pinned_unhonorable_param(category: str | None, user_params: dict) -> tuple[str, str] | None: + if category == "llm-inference" and user_params.get("model"): + return ("model", str(user_params["model"])) + return None + # Per-category in-process cache of the ordered chain (TTL bounded). Keyed by # category; value = (expiry_monotonic, [ordered_slugs]). _CHAIN_CACHE: dict[str, tuple[float, list[str]]] = {} @@ -317,6 +328,23 @@ def _may_act(s: str) -> bool: origin_post_send = settlement == _SETTLEMENT_POST + # 3b. Pinned-param honesty: if the caller pinned a provider-specific param (an + # LLM model) that NO substitute in the group can honor, surface a clean + # error rather than silently serving a different model. The same-provider + # retry above already covered the honorable case; a cross-provider + # substitute would change the model under the caller's feet. + _pinned = _pinned_unhonorable_param(category, user_params) + if _pinned and chain: + field, val = _pinned + out.client_error = ( + f"primary '{primary_slug}' failed and your request pins {field}='{val}', " + f"which no substitute in the '{category}' group can honor. Omit '{field}' to " + f"self-heal across providers, or pin a model the substitute supports." + ) + out.providers_tried.append((primary_slug, f"pinned_{field}_unhonorable")) + out.settlement_class, out.retried_primary = settlement, retried + return out + # 4. Chain through the group, depth-capped. for candidate in chain[: policy.max_depth]: cfg = SERVICE_CONFIGS.get(candidate) diff --git a/apps/api/routers/execute.py b/apps/api/routers/execute.py index f27f519..4201461 100644 --- a/apps/api/routers/execute.py +++ b/apps/api/routers/execute.py @@ -1005,7 +1005,7 @@ async def pay_for_service(request: Request, db=Depends(get_db)): @router.post("/execute") @limiter.limit("60/minute") -async def execute_service(request: Request, db=Depends(get_db)): +async def execute_service(request: Request, response: Response, db=Depends(get_db)): """Call a real external API using Wayforth-managed keys or user BYOK keys.""" import time as _time @@ -1307,6 +1307,10 @@ async def execute_service(request: Request, db=Depends(get_db)): # ── Managed-catalog path ────────────────────────────────────────────────── config = SERVICE_CONFIGS[service_slug] + # Save the user's params BEFORE service defaults are injected, so the failover + # engine can re-map cleanly for a substitute without leaking the primary's + # default param values. + _user_params = dict(params) # Normalise params: resolve aliases, inject defaults, wrap prompt→messages params, _missing = map_params(service_slug, params) if _missing: @@ -1383,77 +1387,64 @@ async def execute_service(request: Request, db=Depends(get_db)): "top_up_url": "https://wayforth.io/billing", }) - start = _time.time() - adapter = ADAPTERS[service_slug] - result = None - error_msg = None - - if service_slug == "assemblyai": - try: - result = await asyncio.wait_for(adapter(params, svc_key), timeout=35.0) - except asyncio.TimeoutError: - error_msg = "Service timeout" - except Exception as e: - error_msg = str(e)[:300] - else: - for attempt in range(2): - try: - result = await asyncio.wait_for(adapter(params, svc_key), timeout=10.0) - break - except asyncio.TimeoutError: - if attempt == 0: - continue - error_msg = "Service timeout" - except Exception as e: - error_msg = str(e)[:300] - logger.warning("managed adapter error: %s attempt=%d error=%s", service_slug, attempt, error_msg) - break - - execution_ms = round((_time.time() - start) * 1000) + # Primary execution via the settlement-aware variant (gives pre/post-send + # classification the failover engine needs). + result, error_msg, execution_ms, _primary_settlement = await _try_execute_managed_ex( + service_slug, params, svc_key + ) _execute_fallback_from: str | None = None _original_failure_code: str | None = None if error_msg and _classify_error(error_msg) == "service_failure": _original_failure_code = _classify_failure(None, error_msg) - new_bal = await _do_refund(db, user_id, credit_cost, service_slug, error_msg, "/execute", balance_after, - _mk_refund_key(getattr(request.state, "request_id", ""), service_slug, "execute_managed")) - # Try one automatic fallback for managed-key calls - _fb_slug = SERVICE_ALTERNATIVES.get(service_slug) - if key_source == "managed" and _fb_slug and _fb_slug in SERVICE_CONFIGS: - _fb_cfg = SERVICE_CONFIGS[_fb_slug] - _fb_api_key = os.environ.get(_fb_cfg["key_var"], "") - if _fb_api_key: - _fb_mapped, _fb_miss = map_params(_fb_slug, params) - if not _fb_miss: - _fb_cost = _fb_cfg["credits"] - _fb_ok, _fb_bal, _fb_tx_id = await check_and_deduct_credits( - db, str(user_id), _fb_cost, "/execute", - service_id=_fb_slug, tx_type="execution", - agent_id=agent_id, api_key_id=str(_api_key_id), - return_tx_id=True, - ) - if _fb_ok: - result, _fb_err, execution_ms = await _try_execute_managed(_fb_slug, _fb_mapped, _fb_api_key) - if _fb_err and _classify_error(_fb_err) == "service_failure": - await _do_refund(db, user_id, _fb_cost, _fb_slug, _fb_err, "/execute", _fb_bal, - _mk_refund_key(getattr(request.state, "request_id", ""), _fb_slug, "execute_managed_fb")) - result = None - elif _fb_err: - raise HTTPException(status_code=400, detail={"error": _fb_err, "refunded": False, "credits_restored": 0}) - else: - _execute_fallback_from = service_slug - service_slug = _fb_slug - credit_cost = _fb_cost - balance_after = _fb_bal - # Point signal patch at the fallback tx row. - _tx_id = _fb_tx_id - if result is None: - # Patch the original tx with the failure code before raising. - from main import app as _app_ref + from main import app as _app_ref + if key_source == "managed": + # Multi-hop substitution/failover engine (same one /proxy uses). Charges + # only the served provider; never double-charges; surfaces cleanly when a + # pinned provider-specific param (LLM model) can't be honored by a substitute. + from core.substitution import run_with_failover + outcome = await run_with_failover( + db, pool=_app_ref.state.pool, + request_id=getattr(request.state, "request_id", ""), + user_id=str(user_id), api_key_id=str(_api_key_id), agent_id=agent_id, + primary_slug=service_slug, user_params=_user_params, + primary_error=error_msg, primary_settlement=_primary_settlement, + primary_cost=credit_cost, primary_balance_after=balance_after, + primary_tx_id=_tx_id, primary_svc_key=svc_key, rail="managed", + ) + if outcome.client_error: + raise HTTPException(status_code=400, detail={ + "error": outcome.client_error, "refunded": True, "credits_restored": credit_cost, + }) + if outcome.served_slug is None: + asyncio.create_task(_patch_tx_signals( + _app_ref.state.pool, _tx_id, + failure_code=outcome.original_failure_code, task_query_text=_preceding_query, + )) + raise HTTPException(status_code=502, detail={ + "error": "all_providers_failed", + "category": outcome.category, + "providers_tried": [{"provider": s, "reason": r} for s, r in outcome.providers_tried], + "refunded": True, + "credits_restored": credit_cost, + "credits_remaining": outcome.balance_after, + "calls_remaining": outcome.balance_after, # backward compat + }) + _execute_fallback_from = outcome.fallback_from + service_slug = outcome.served_slug + credit_cost = outcome.cost + balance_after = outcome.balance_after + _tx_id = outcome.tx_id + result = outcome.result + execution_ms = outcome.execution_ms + else: + # BYOK: the user's own key — no substitution. Refund and surface. + new_bal = await _do_refund( + db, user_id, credit_cost, service_slug, error_msg, "/execute", balance_after, + _mk_refund_key(getattr(request.state, "request_id", ""), service_slug, "execute_byok")) asyncio.create_task(_patch_tx_signals( _app_ref.state.pool, _tx_id, - failure_code=_original_failure_code, - task_query_text=_preceding_query, + failure_code=_original_failure_code, task_query_text=_preceding_query, )) raise HTTPException(status_code=503, detail={ "error": "Service unavailable", @@ -1501,9 +1492,15 @@ async def execute_service(request: Request, db=Depends(get_db)): substitution_reason=_original_failure_code if _execute_fallback_from else None, )) + # Visible self-heal surface (parity with /proxy). + response.headers["X-Wayforth-Served-By"] = service_slug + response.headers["X-Wayforth-Fallback"] = "true" if _execute_fallback_from else "false" + resp = { "status": "ok", "service": service_slug, + "served_by": service_slug, + "fallback": bool(_execute_fallback_from), "result": result, "credits_deducted": credit_cost, "execution_ms": execution_ms, diff --git a/apps/api/tests/test_substitution.py b/apps/api/tests/test_substitution.py index 1ce31e5..1ac4cef 100644 --- a/apps/api/tests/test_substitution.py +++ b/apps/api/tests/test_substitution.py @@ -281,6 +281,38 @@ def test_invalid_body_after_200_does_not_serve(harness): assert "brave" in [r[0] for r in harness["refunds"]] +# ── 6b. pinned-model honesty (surface, don't strip/default) ─────────────────── + +async def _run_llm(primary="groq", user_params=None, settlement=_SETTLEMENT_PRE, policy=None): + if policy is None: + policy = FailoverPolicy(retry_primary_on_transient=False) + return await run_with_failover( + db=None, pool=None, request_id="r", user_id="u", api_key_id="k", agent_id=None, + primary_slug=primary, user_params=user_params or {"messages": [{"role": "user", "content": "hi"}]}, + primary_error="Service timeout", primary_settlement=settlement, + primary_cost=3, primary_balance_after=997, primary_tx_id="tx", primary_svc_key="K", + rail="managed", policy=policy, + ) + + +def test_pinned_model_surfaces_not_substituted(harness): + harness["set_chain"]("llm-inference", ["mistral", "together"]) + harness["set_exec"]({"mistral": ({"content": "hi"}, None, 5, _SETTLEMENT_PRE)}) + out = asyncio.run(_run_llm(user_params={ + "messages": [{"role": "user", "content": "hi"}], "model": "llama-3.3-70b-versatile"})) + assert out.served_slug is None + assert out.client_error and "model" in out.client_error # clean, explanatory error + assert ("mistral", 4) not in harness["deducts"] # substitute NOT attempted (no silent default) + assert "groq" in [r[0] for r in harness["refunds"]] # primary still refunded + + +def test_unpinned_llm_self_heals(harness): + harness["set_chain"]("llm-inference", ["mistral"]) + harness["set_exec"]({"mistral": ({"content": "pong"}, None, 5, _SETTLEMENT_PRE)}) + out = asyncio.run(_run_llm(user_params={"messages": [{"role": "user", "content": "hi"}]})) + assert out.served_slug == "mistral" # unpinned → self-heals as before + + # ── 7. event row emitted per hop with settlement_class ──────────────────────── def test_events_emitted_per_hop(harness):