Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions apps/api/core/substitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,17 @@
def _supports_idempotency_key(slug: str) -> bool:
return slug in _IDEMPOTENCY_KEY_PROVIDERS


# Provider-specific params that CANNOT carry to a different provider in the group.
# For llm-inference, an explicitly-pinned `model` is a vendor model name that a
# substitute cannot honor. Policy (decided): SURFACE a clean error — never strip
# the param and never silently fall back to the substitute's default model. (A
# curated model-equivalence map is a separate follow-up.) Unpinned calls self-heal.
def _pinned_unhonorable_param(category: str | None, user_params: dict) -> tuple[str, str] | None:
if category == "llm-inference" and user_params.get("model"):
return ("model", str(user_params["model"]))
return None

# Per-category in-process cache of the ordered chain (TTL bounded). Keyed by
# category; value = (expiry_monotonic, [ordered_slugs]).
_CHAIN_CACHE: dict[str, tuple[float, list[str]]] = {}
Expand Down Expand Up @@ -317,6 +328,23 @@ def _may_act(s: str) -> bool:

origin_post_send = settlement == _SETTLEMENT_POST

# 3b. Pinned-param honesty: if the caller pinned a provider-specific param (an
# LLM model) that NO substitute in the group can honor, surface a clean
# error rather than silently serving a different model. The same-provider
# retry above already covered the honorable case; a cross-provider
# substitute would change the model under the caller's feet.
_pinned = _pinned_unhonorable_param(category, user_params)
if _pinned and chain:
field, val = _pinned
out.client_error = (
f"primary '{primary_slug}' failed and your request pins {field}='{val}', "
f"which no substitute in the '{category}' group can honor. Omit '{field}' to "
f"self-heal across providers, or pin a model the substitute supports."
)
out.providers_tried.append((primary_slug, f"pinned_{field}_unhonorable"))
out.settlement_class, out.retried_primary = settlement, retried
return out

# 4. Chain through the group, depth-capped.
for candidate in chain[: policy.max_depth]:
cfg = SERVICE_CONFIGS.get(candidate)
Expand Down
127 changes: 62 additions & 65 deletions apps/api/routers/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,7 +1005,7 @@ async def pay_for_service(request: Request, db=Depends(get_db)):

@router.post("/execute")
@limiter.limit("60/minute")
async def execute_service(request: Request, db=Depends(get_db)):
async def execute_service(request: Request, response: Response, db=Depends(get_db)):
"""Call a real external API using Wayforth-managed keys or user BYOK keys."""
import time as _time

Expand Down Expand Up @@ -1307,6 +1307,10 @@ async def execute_service(request: Request, db=Depends(get_db)):
# ── Managed-catalog path ──────────────────────────────────────────────────
config = SERVICE_CONFIGS[service_slug]

# Save the user's params BEFORE service defaults are injected, so the failover
# engine can re-map cleanly for a substitute without leaking the primary's
# default param values.
_user_params = dict(params)
# Normalise params: resolve aliases, inject defaults, wrap prompt→messages
params, _missing = map_params(service_slug, params)
if _missing:
Expand Down Expand Up @@ -1383,77 +1387,64 @@ async def execute_service(request: Request, db=Depends(get_db)):
"top_up_url": "https://wayforth.io/billing",
})

start = _time.time()
adapter = ADAPTERS[service_slug]
result = None
error_msg = None

if service_slug == "assemblyai":
try:
result = await asyncio.wait_for(adapter(params, svc_key), timeout=35.0)
except asyncio.TimeoutError:
error_msg = "Service timeout"
except Exception as e:
error_msg = str(e)[:300]
else:
for attempt in range(2):
try:
result = await asyncio.wait_for(adapter(params, svc_key), timeout=10.0)
break
except asyncio.TimeoutError:
if attempt == 0:
continue
error_msg = "Service timeout"
except Exception as e:
error_msg = str(e)[:300]
logger.warning("managed adapter error: %s attempt=%d error=%s", service_slug, attempt, error_msg)
break

execution_ms = round((_time.time() - start) * 1000)
# Primary execution via the settlement-aware variant (gives pre/post-send
# classification the failover engine needs).
result, error_msg, execution_ms, _primary_settlement = await _try_execute_managed_ex(
service_slug, params, svc_key
)

_execute_fallback_from: str | None = None
_original_failure_code: str | None = None
if error_msg and _classify_error(error_msg) == "service_failure":
_original_failure_code = _classify_failure(None, error_msg)
new_bal = await _do_refund(db, user_id, credit_cost, service_slug, error_msg, "/execute", balance_after,
_mk_refund_key(getattr(request.state, "request_id", ""), service_slug, "execute_managed"))
# Try one automatic fallback for managed-key calls
_fb_slug = SERVICE_ALTERNATIVES.get(service_slug)
if key_source == "managed" and _fb_slug and _fb_slug in SERVICE_CONFIGS:
_fb_cfg = SERVICE_CONFIGS[_fb_slug]
_fb_api_key = os.environ.get(_fb_cfg["key_var"], "")
if _fb_api_key:
_fb_mapped, _fb_miss = map_params(_fb_slug, params)
if not _fb_miss:
_fb_cost = _fb_cfg["credits"]
_fb_ok, _fb_bal, _fb_tx_id = await check_and_deduct_credits(
db, str(user_id), _fb_cost, "/execute",
service_id=_fb_slug, tx_type="execution",
agent_id=agent_id, api_key_id=str(_api_key_id),
return_tx_id=True,
)
if _fb_ok:
result, _fb_err, execution_ms = await _try_execute_managed(_fb_slug, _fb_mapped, _fb_api_key)
if _fb_err and _classify_error(_fb_err) == "service_failure":
await _do_refund(db, user_id, _fb_cost, _fb_slug, _fb_err, "/execute", _fb_bal,
_mk_refund_key(getattr(request.state, "request_id", ""), _fb_slug, "execute_managed_fb"))
result = None
elif _fb_err:
raise HTTPException(status_code=400, detail={"error": _fb_err, "refunded": False, "credits_restored": 0})
else:
_execute_fallback_from = service_slug
service_slug = _fb_slug
credit_cost = _fb_cost
balance_after = _fb_bal
# Point signal patch at the fallback tx row.
_tx_id = _fb_tx_id
if result is None:
# Patch the original tx with the failure code before raising.
from main import app as _app_ref
from main import app as _app_ref
if key_source == "managed":
# Multi-hop substitution/failover engine (same one /proxy uses). Charges
# only the served provider; never double-charges; surfaces cleanly when a
# pinned provider-specific param (LLM model) can't be honored by a substitute.
from core.substitution import run_with_failover
outcome = await run_with_failover(
db, pool=_app_ref.state.pool,
request_id=getattr(request.state, "request_id", ""),
user_id=str(user_id), api_key_id=str(_api_key_id), agent_id=agent_id,
primary_slug=service_slug, user_params=_user_params,
primary_error=error_msg, primary_settlement=_primary_settlement,
primary_cost=credit_cost, primary_balance_after=balance_after,
primary_tx_id=_tx_id, primary_svc_key=svc_key, rail="managed",
)
if outcome.client_error:
raise HTTPException(status_code=400, detail={
"error": outcome.client_error, "refunded": True, "credits_restored": credit_cost,
})
if outcome.served_slug is None:
asyncio.create_task(_patch_tx_signals(
_app_ref.state.pool, _tx_id,
failure_code=outcome.original_failure_code, task_query_text=_preceding_query,
))
raise HTTPException(status_code=502, detail={
"error": "all_providers_failed",
"category": outcome.category,
"providers_tried": [{"provider": s, "reason": r} for s, r in outcome.providers_tried],
"refunded": True,
"credits_restored": credit_cost,
"credits_remaining": outcome.balance_after,
"calls_remaining": outcome.balance_after, # backward compat
})
_execute_fallback_from = outcome.fallback_from
service_slug = outcome.served_slug
credit_cost = outcome.cost
balance_after = outcome.balance_after
_tx_id = outcome.tx_id
result = outcome.result
execution_ms = outcome.execution_ms
else:
# BYOK: the user's own key — no substitution. Refund and surface.
new_bal = await _do_refund(
db, user_id, credit_cost, service_slug, error_msg, "/execute", balance_after,
_mk_refund_key(getattr(request.state, "request_id", ""), service_slug, "execute_byok"))
asyncio.create_task(_patch_tx_signals(
_app_ref.state.pool, _tx_id,
failure_code=_original_failure_code,
task_query_text=_preceding_query,
failure_code=_original_failure_code, task_query_text=_preceding_query,
))
raise HTTPException(status_code=503, detail={
"error": "Service unavailable",
Expand Down Expand Up @@ -1501,9 +1492,15 @@ async def execute_service(request: Request, db=Depends(get_db)):
substitution_reason=_original_failure_code if _execute_fallback_from else None,
))

# Visible self-heal surface (parity with /proxy).
response.headers["X-Wayforth-Served-By"] = service_slug
response.headers["X-Wayforth-Fallback"] = "true" if _execute_fallback_from else "false"

resp = {
"status": "ok",
"service": service_slug,
"served_by": service_slug,
"fallback": bool(_execute_fallback_from),
"result": result,
"credits_deducted": credit_cost,
"execution_ms": execution_ms,
Expand Down
32 changes: 32 additions & 0 deletions apps/api/tests/test_substitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,38 @@ def test_invalid_body_after_200_does_not_serve(harness):
assert "brave" in [r[0] for r in harness["refunds"]]


# ── 6b. pinned-model honesty (surface, don't strip/default) ───────────────────

async def _run_llm(primary="groq", user_params=None, settlement=_SETTLEMENT_PRE, policy=None):
if policy is None:
policy = FailoverPolicy(retry_primary_on_transient=False)
return await run_with_failover(
db=None, pool=None, request_id="r", user_id="u", api_key_id="k", agent_id=None,
primary_slug=primary, user_params=user_params or {"messages": [{"role": "user", "content": "hi"}]},
primary_error="Service timeout", primary_settlement=settlement,
primary_cost=3, primary_balance_after=997, primary_tx_id="tx", primary_svc_key="K",
rail="managed", policy=policy,
)


def test_pinned_model_surfaces_not_substituted(harness):
harness["set_chain"]("llm-inference", ["mistral", "together"])
harness["set_exec"]({"mistral": ({"content": "hi"}, None, 5, _SETTLEMENT_PRE)})
out = asyncio.run(_run_llm(user_params={
"messages": [{"role": "user", "content": "hi"}], "model": "llama-3.3-70b-versatile"}))
assert out.served_slug is None
assert out.client_error and "model" in out.client_error # clean, explanatory error
assert ("mistral", 4) not in harness["deducts"] # substitute NOT attempted (no silent default)
assert "groq" in [r[0] for r in harness["refunds"]] # primary still refunded


def test_unpinned_llm_self_heals(harness):
harness["set_chain"]("llm-inference", ["mistral"])
harness["set_exec"]({"mistral": ({"content": "pong"}, None, 5, _SETTLEMENT_PRE)})
out = asyncio.run(_run_llm(user_params={"messages": [{"role": "user", "content": "hi"}]}))
assert out.served_slug == "mistral" # unpinned → self-heals as before


# ── 7. event row emitted per hop with settlement_class ────────────────────────

def test_events_emitted_per_hop(harness):
Expand Down
Loading