From 32e184a36919844a2005127c2cad77211987c714 Mon Sep 17 00:00:00 2001 From: djliden <7102904+djliden@users.noreply.github.com> Date: Mon, 8 Jun 2026 11:52:35 -0500 Subject: [PATCH 1/3] Migrate DAIS agents to Apps AI Gateway --- apps/caspers-ops-dashboard/app.yaml | 10 +- apps/caspers-ops-dashboard/app/main.py | 162 ++-- apps/complaint-agent/agent.py | 271 ++++++ apps/complaint-agent/app.yaml | 12 + apps/complaint-agent/requirements.txt | 8 + apps/complaint-agent/start_server.py | 46 + apps/refund-agent/agent.py | 235 +++++ apps/refund-agent/app.yaml | 12 + apps/refund-agent/requirements.txt | 10 + apps/refund-agent/start_server.py | 53 + databricks.yml | 26 +- jobs/complaint_agent_stream.ipynb | 143 +-- jobs/refund_recommender_stream.ipynb | 156 ++- stages/complaint_agent.ipynb | 1087 ++++++--------------- stages/complaint_agent_stream.ipynb | 24 +- stages/complaint_evaluation.ipynb | 156 ++- stages/operational_app.ipynb | 293 +++++- stages/refund_evaluation.ipynb | 132 +-- stages/refunder_agent.ipynb | 1240 ++++++------------------ stages/refunder_stream.ipynb | 50 +- utils/agent_app_client.py | 224 +++++ 21 files changed, 2173 insertions(+), 2177 deletions(-) create mode 100644 apps/complaint-agent/agent.py create mode 100644 apps/complaint-agent/app.yaml create mode 100644 apps/complaint-agent/requirements.txt create mode 100644 apps/complaint-agent/start_server.py create mode 100644 apps/refund-agent/agent.py create mode 100644 apps/refund-agent/app.yaml create mode 100644 apps/refund-agent/requirements.txt create mode 100644 apps/refund-agent/start_server.py create mode 100644 utils/agent_app_client.py diff --git a/apps/caspers-ops-dashboard/app.yaml b/apps/caspers-ops-dashboard/app.yaml index 1654d25..d29e1a3 100644 --- a/apps/caspers-ops-dashboard/app.yaml +++ b/apps/caspers-ops-dashboard/app.yaml @@ -47,10 +47,14 @@ env: # #page= fragment and the user sees the dashboard's default landing page. 
- name: OPS_DASHBOARD_PAGE value: '' - # Custom agent endpoints — populated by operational_app stage from CATALOG - - name: REFUND_AGENT_ENDPOINT + # Custom agent Apps — populated by operational_app stage from CATALOG + - name: REFUND_AGENT_APP_NAME value: '' - - name: COMPLAINT_AGENT_ENDPOINT + - name: REFUND_AGENT_APP_URL + value: '' + - name: COMPLAINT_AGENT_APP_NAME + value: '' + - name: COMPLAINT_AGENT_APP_URL value: '' - name: REFUND_MANAGER_APP_URL value: '' diff --git a/apps/caspers-ops-dashboard/app/main.py b/apps/caspers-ops-dashboard/app/main.py index 9ccffae..d954144 100644 --- a/apps/caspers-ops-dashboard/app/main.py +++ b/apps/caspers-ops-dashboard/app/main.py @@ -50,8 +50,10 @@ SUPERVISOR_ENDPOINT = os.environ.get("SUPERVISOR_ENDPOINT", "") SUPERVISOR_TILE_ID = os.environ.get("SUPERVISOR_TILE_ID", "") # written by operational_lakebase stage SUPERVISOR_MLFLOW_EXP_ID = os.environ.get("SUPERVISOR_MLFLOW_EXPERIMENT_ID", "") # written by operational_lakebase stage -REFUND_AGENT_ENDPOINT = os.environ.get("REFUND_AGENT_ENDPOINT", "") -COMPLAINT_AGENT_ENDPOINT = os.environ.get("COMPLAINT_AGENT_ENDPOINT", "") +REFUND_AGENT_APP_NAME = os.environ.get("REFUND_AGENT_APP_NAME", "") +REFUND_AGENT_APP_URL = os.environ.get("REFUND_AGENT_APP_URL", "") +COMPLAINT_AGENT_APP_NAME = os.environ.get("COMPLAINT_AGENT_APP_NAME", "") +COMPLAINT_AGENT_APP_URL = os.environ.get("COMPLAINT_AGENT_APP_URL", "") REFUND_MANAGER_APP_URL = os.environ.get("REFUND_MANAGER_APP_URL", "") SUPPORT_CONSOLE_APP_URL = os.environ.get("SUPPORT_CONSOLE_APP_URL", "") LAKEBASE_INSTANCE = os.environ.get("LAKEBASE_ENDPOINT_PATH", "") # non-empty = DB enabled @@ -315,8 +317,10 @@ def _embed_url(d_id: str) -> str: "warehouse_id": WAREHOUSE_ID, "db_enabled": bool(LAKEBASE_INSTANCE), "supervisor_enabled": bool(SUPERVISOR_ENDPOINT), - "refund_agent_endpoint": REFUND_AGENT_ENDPOINT, - "complaint_agent_endpoint": COMPLAINT_AGENT_ENDPOINT, + "refund_agent_app_name": REFUND_AGENT_APP_NAME, + 
"refund_agent_app_url": REFUND_AGENT_APP_URL, + "complaint_agent_app_name": COMPLAINT_AGENT_APP_NAME, + "complaint_agent_app_url": COMPLAINT_AGENT_APP_URL, "refund_manager_app_url": REFUND_MANAGER_APP_URL, "support_console_app_url": SUPPORT_CONSOLE_APP_URL, "mlflow_experiment_id": mlflow_experiment_id, @@ -1047,9 +1051,23 @@ class ComplaintRequest(BaseModel): order_id: str = "" -def _call_agent_endpoint(endpoint_name: str, payload: dict) -> dict: - """Call a model serving endpoint (ChatAgent or ResponsesAgent) and return the parsed response body.""" - url = f"{(_sdk_config.host or '').rstrip('/')}/serving-endpoints/{endpoint_name}/invocations" +def _agent_app_url(app_name: str, configured_url: str) -> str: + if configured_url: + return configured_url.rstrip("/") + if app_name: + app_info = _ws.apps.get(app_name) + url = getattr(app_info, "url", "") or "" + if url: + return url.rstrip("/") + return "" + + +def _call_agent_app(app_name: str, configured_url: str, payload: dict) -> dict: + """Call a DAIS custom agent Databricks App via MLflow AgentServer /responses.""" + base_url = _agent_app_url(app_name, configured_url) + if not base_url: + raise HTTPException(status_code=503, detail=f"Agent app {app_name or '(unknown)'} not configured.") + url = f"{base_url}/responses" headers = {"Content-Type": "application/json"} headers.update(_sdk_config.authenticate()) resp = httpx.post(url, headers=headers, json=payload, timeout=120.0) @@ -1057,6 +1075,35 @@ def _call_agent_endpoint(endpoint_name: str, payload: dict) -> dict: return resp.json() +def _extract_agent_output_text(data: dict) -> str: + output_text = data.get("output_text") + if isinstance(output_text, str) and output_text: + return output_text + + output = data.get("output") or [] + if isinstance(output, dict): + output = [output] + for out in output: + if not isinstance(out, dict): + continue + content = out.get("content") + if isinstance(content, str) and content: + return content + if isinstance(content, dict): 
+ content = [content] + if isinstance(content, list): + for part in content: + if isinstance(part, dict): + text = part.get("text") + if isinstance(text, str) and text: + return text + + choices = data.get("choices") or [] + if choices: + return (choices[0].get("message") or {}).get("content", "") + return "" + + def _build_refund_user_message(req: "RefundRequest") -> str: """Compose the user message sent to the refund agent. @@ -1084,52 +1131,42 @@ def _build_refund_user_message(req: "RefundRequest") -> str: @app.post("/api/refund") def refund(req: RefundRequest): """Call the refund agent for a given order_id and return a structured decision.""" - if not REFUND_AGENT_ENDPOINT: - raise HTTPException(status_code=503, detail="Refund agent endpoint not configured.") + if not (REFUND_AGENT_APP_NAME or REFUND_AGENT_APP_URL): + raise HTTPException(status_code=503, detail="Refund agent app not configured.") try: user_msg = _build_refund_user_message(req) - data = _call_agent_endpoint( - REFUND_AGENT_ENDPOINT, - {"messages": [{"role": "user", "content": user_msg}]}, + data = _call_agent_app( + REFUND_AGENT_APP_NAME, + REFUND_AGENT_APP_URL, + {"input": [{"role": "user", "content": user_msg}]}, ) - # Extract the last assistant message from the ChatAgent response - messages = data.get("messages") or [] - for msg in reversed(messages): - role = msg.get("role", "") - content = msg.get("content", "") - if role == "assistant" and content: - # Robustly extract a JSON object from the assistant message: - # the agent's prompt asks for raw JSON, but LLMs often wrap it in - # ```json … ``` fences or sprinkle commentary around it. Try a - # bare json.loads first, then fall back to the first {...} match. 
- cleaned = content.strip() - if cleaned.startswith("```"): - # strip markdown code fence (```json … ``` or ``` … ```) - cleaned = cleaned.strip("`") - if cleaned.lower().startswith("json"): - cleaned = cleaned[4:] - cleaned = cleaned.strip() - decision = None + content = _extract_agent_output_text(data) + if not content: + raise HTTPException(status_code=502, detail="No output in refund agent response.") + cleaned = content.strip() + if cleaned.startswith("```"): + cleaned = cleaned.strip("`") + if cleaned.lower().startswith("json"): + cleaned = cleaned[4:] + cleaned = cleaned.strip() + decision = None + try: + decision = json.loads(cleaned) + except Exception: + m = re.search(r"\{[\s\S]*\}", cleaned) + if m: try: - decision = json.loads(cleaned) + decision = json.loads(m.group(0)) except Exception: - import re as _re - m = _re.search(r"\{[\s\S]*\}", cleaned) - if m: - try: - decision = json.loads(m.group(0)) - except Exception: - decision = None - if decision is not None: - return { - "order_id": req.order_id, - "refund_usd": float(decision.get("refund_usd", 0)), - "refund_class": decision.get("refund_class", "none"), - "reason": decision.get("reason", ""), - } - # Last-resort fallback: surface raw text so the UI can show something. 
- return {"order_id": req.order_id, "raw": content} - raise HTTPException(status_code=502, detail="No assistant message in refund agent response.") + decision = None + if decision is not None: + return { + "order_id": req.order_id, + "refund_usd": float(decision.get("refund_usd", 0)), + "refund_class": decision.get("refund_class", "none"), + "reason": decision.get("reason", ""), + } + return {"order_id": req.order_id, "raw": content} except HTTPException: raise except Exception as e: @@ -1140,25 +1177,18 @@ def refund(req: RefundRequest): @app.post("/api/complaint") def complaint(req: ComplaintRequest): """Call the complaint agent for a raw complaint text and return a structured classification.""" - if not COMPLAINT_AGENT_ENDPOINT: - raise HTTPException(status_code=503, detail="Complaint agent endpoint not configured.") + if not (COMPLAINT_AGENT_APP_NAME or COMPLAINT_AGENT_APP_URL): + raise HTTPException(status_code=503, detail="Complaint agent app not configured.") content = req.complaint_text if req.order_id: content = f"{content} (Order ID: {req.order_id})" try: - data = _call_agent_endpoint( - COMPLAINT_AGENT_ENDPOINT, + data = _call_agent_app( + COMPLAINT_AGENT_APP_NAME, + COMPLAINT_AGENT_APP_URL, {"input": [{"role": "user", "content": content}]}, ) - # ResponsesAgent returns output list or choices - output_text = "" - for out in data.get("output", []): - for part in (out.get("content") or []): - output_text += part.get("text", "") - if not output_text: - choices = data.get("choices") or [] - if choices: - output_text = (choices[0].get("message") or {}).get("content", "") + output_text = _extract_agent_output_text(data) if output_text: try: result = json.loads(output_text) @@ -1276,14 +1306,14 @@ def list_agents(): "url": url, "id": tile_id}) _CUSTOM_AGENTS = [ - {"name": "Refund Agent", "icon": "💳", "endpoint": REFUND_AGENT_ENDPOINT}, - {"name": "Complaint Agent", "icon": "📬", "endpoint": COMPLAINT_AGENT_ENDPOINT}, + {"name": "Refund Agent", "icon": "💳", 
"app_name": REFUND_AGENT_APP_NAME, "app_url": REFUND_AGENT_APP_URL}, + {"name": "Complaint Agent", "icon": "📬", "app_name": COMPLAINT_AGENT_APP_NAME, "app_url": COMPLAINT_AGENT_APP_URL}, ] for ca in _CUSTOM_AGENTS: - ep = ca["endpoint"] - url = f"{host}/ml/endpoints/{ep}" if ep and host else "" + app_name = ca["app_name"] + url = ca["app_url"] or (f"{host}/apps/{app_name}" if app_name and host else "") agents.append({"name": ca["name"], "icon": ca["icon"], "type": "agent", - "url": url, "id": ep}) + "url": url, "id": app_name}) return agents diff --git a/apps/complaint-agent/agent.py b/apps/complaint-agent/agent.py new file mode 100644 index 0000000..94c4671 --- /dev/null +++ b/apps/complaint-agent/agent.py @@ -0,0 +1,271 @@ +import os +import uuid +import warnings +from typing import Literal, Optional + +import dspy +import mlflow +from databricks.sdk.core import Config +from mlflow.genai.agent_server import invoke +from mlflow.types.responses import ResponsesAgentRequest, ResponsesAgentResponse +from openai import OpenAI +from pydantic import BaseModel, Field, ValidationError, field_validator +from unitycatalog.ai.core.base import get_uc_function_client + + +warnings.filterwarnings("ignore", message=".*Ignoring the default notebook Spark session.*") + +mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "databricks")) +mlflow.set_registry_uri(os.getenv("MLFLOW_REGISTRY_URI", "databricks-uc")) +mlflow.dspy.autolog(log_traces=True) + +CATALOG = os.environ["DATABRICKS_CATALOG"] +LLM_MODEL = os.environ["LLM_MODEL"] +HOST = (os.environ.get("DATABRICKS_HOST") or Config().host).rstrip("/") +GATEWAY_BASE_URL = f"{HOST}/ai-gateway/mlflow/v1" + + +def _auth_header() -> str: + header = Config().authenticate().get("Authorization", "") + if not header.startswith("Bearer "): + raise RuntimeError("Databricks OAuth bearer token is unavailable") + return header + + +def _token() -> str: + return _auth_header().removeprefix("Bearer ") + + +def _configure_dspy() -> None: + lm 
= dspy.LM( + f"openai/{LLM_MODEL}", + api_base=GATEWAY_BASE_URL, + api_key=_token(), + max_tokens=2000, + num_retries=20, + cache=False, + ) + dspy.configure(lm=lm) + + +def _validate_gateway_endpoint() -> None: + client = OpenAI(api_key=_token(), base_url=GATEWAY_BASE_URL, timeout=30) + client.chat.completions.create( + model=LLM_MODEL, + messages=[{"role": "user", "content": "Say gateway ok."}], + max_tokens=8, + ) + + +_validate_gateway_endpoint() +_uc_client = None + + +def _client(): + global _uc_client + if _uc_client is None: + _uc_client = get_uc_function_client() + return _uc_client + + +class ComplaintResponse(BaseModel): + """Structured output for complaint triage decisions.""" + + order_id: str + complaint_category: Literal[ + "delivery_delay", + "missing_items", + "food_quality", + "service_issue", + "billing", + "other", + ] = Field(description="Exactly ONE primary complaint category") + decision: Literal["suggest_credit", "escalate"] + credit_amount: Optional[float] = None + confidence: Optional[Literal["high", "medium", "low"]] = None + priority: Optional[Literal["standard", "urgent"]] = None + rationale: str + + @field_validator("complaint_category", mode="before") + @classmethod + def parse_category(cls, v): + if not isinstance(v, str): + return v + valid = [ + "delivery_delay", + "missing_items", + "food_quality", + "service_issue", + "billing", + "other", + ] + v_lower = v.lower().strip() + if v_lower in valid: + return v_lower + for cat in valid: + if cat in v_lower: + return cat + return "other" + + @field_validator("confidence", mode="before") + @classmethod + def parse_confidence(cls, v): + if v is None or (isinstance(v, str) and v.lower() == "null"): + return None + if isinstance(v, str): + v_lower = v.lower().strip() + return v_lower if v_lower in ["high", "medium", "low"] else "medium" + return v + + @field_validator("priority", mode="before") + @classmethod + def parse_priority(cls, v): + if v is None or (isinstance(v, str) and v.lower() 
== "null"): + return None + if isinstance(v, str): + v_lower = v.lower().strip() + return v_lower if v_lower in ["standard", "urgent"] else "standard" + return v + + +class ComplaintTriage(dspy.Signature): + """Analyze customer complaints for Casper's Kitchens and recommend triage actions. + + Process: + 1. Extract order_id from complaint + 2. Use get_order_overview(order_id) for order details and items + 3. Use get_order_timing(order_id) for delivery timing + 4. For delays, use get_location_timings(location) for percentile benchmarks + 5. Make data-backed decision + + Decision Framework: + + SUGGEST_CREDIT (with credit_amount and confidence): + - Delivery delays: Compare actual delivery time to location percentiles + * P99: Suggest 25% of order total (high confidence) + - Missing items: Use actual item prices from order data when available + * Verify claimed item exists in order (affects confidence) + * Use real costs from order data, or estimate $8-12 per item if unavailable + - Food quality: 20-40% of order total based on severity + * Minor issues (slightly cold, minor preparation issue): 20% (medium confidence) + * Major issues (completely inedible, wrong preparation, health concern): 40% (high confidence) + * Vague complaints ("bad", "gross"): escalate instead + + ESCALATE (with priority): + - priority="standard": Vague complaints, missing data, billing issues, service complaints + - priority="urgent": Legal threats, health/safety concerns, suspected fraud, abusive language + + Output Requirements: + - For suggest_credit: credit_amount is REQUIRED and must be a number (can be 0.0 if no credit warranted), confidence is REQUIRED, priority must be null + - For escalate: priority is REQUIRED, credit_amount and confidence must be null + - complaint_category: Choose EXACTLY ONE category (the primary one) + - Rationale must cite specific evidence (delivery times, percentiles, item verification, order total) + - Rationale should be detailed but under 150 words + - 
Round credit amounts to nearest $0.50 + - Confidence: high (strong data), medium (reasonable inference), low (weak/contradictory) + """ + + complaint: str = dspy.InputField(desc="Customer complaint text") + order_id: str = dspy.OutputField(desc="Extracted order ID") + complaint_category: str = dspy.OutputField( + desc="EXACTLY ONE category: delivery_delay, missing_items, food_quality, service_issue, billing, or other" + ) + decision: str = dspy.OutputField(desc="EXACTLY ONE: suggest_credit or escalate") + credit_amount: str = dspy.OutputField(desc="If suggest_credit: a number. If escalate: null") + confidence: str = dspy.OutputField(desc="If suggest_credit: high, medium, or low. If escalate: null") + priority: str = dspy.OutputField(desc="If escalate: standard or urgent. If suggest_credit: null") + rationale: str = dspy.OutputField(desc="Data-focused justification citing specific evidence") + + +def get_order_overview(order_id: str) -> str: + """Get order details including items, location, and customer info.""" + result = _client().execute_function(f"{CATALOG}.ai.get_order_overview", {"oid": order_id}) + return str(result.value) + + +def get_order_timing(order_id: str) -> str: + """Get timing information for a specific order.""" + result = _client().execute_function(f"{CATALOG}.ai.get_order_timing", {"oid": order_id}) + return str(result.value) + + +def get_location_timings(location: str) -> str: + """Get delivery time percentiles for a specific location.""" + result = _client().execute_function(f"{CATALOG}.ai.get_location_timings", {"loc": location}) + return str(result.value) + + +class ComplaintTriageModule(dspy.Module): + def __init__(self): + super().__init__() + self.react = dspy.ReAct( + signature=ComplaintTriage, + tools=[get_order_overview, get_order_timing, get_location_timings], + max_iters=10, + ) + + def forward(self, complaint: str, max_retries: int = 2) -> ComplaintResponse: + for attempt in range(max_retries + 1): + try: + result = 
self.react(complaint=complaint) + credit_amount = None + if result.credit_amount and result.credit_amount.lower() != "null": + try: + credit_amount = float(result.credit_amount) + except (ValueError, TypeError): + credit_amount = None + if result.decision == "suggest_credit" and credit_amount is None: + credit_amount = 0.0 + return ComplaintResponse( + order_id=result.order_id, + complaint_category=result.complaint_category, + decision=result.decision, + credit_amount=credit_amount, + confidence=result.confidence, + priority=result.priority, + rationale=result.rationale, + ) + except (ValidationError, ValueError): + if attempt >= max_retries: + raise + raise RuntimeError("Complaint triage failed after retries") + + +def _msg_to_dict(msg) -> dict: + if isinstance(msg, dict): + return msg + if hasattr(msg, "model_dump"): + return msg.model_dump() + if hasattr(msg, "dict"): + return msg.dict() + raise TypeError(f"Unsupported message type: {type(msg).__name__}") + + +def _text_output(text: str, item_id: str | None = None) -> dict: + return { + "id": item_id or str(uuid.uuid4()), + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + + +@invoke() +def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse: + _configure_dspy() + complaint = None + for msg in request.input: + msg_dict = _msg_to_dict(msg) + if msg_dict.get("role") == "user": + complaint = msg_dict.get("content", "") + break + if not complaint: + raise ValueError("No user message found in request") + + result = ComplaintTriageModule()(complaint=complaint) + return ResponsesAgentResponse( + output=[_text_output(result.model_dump_json())], + custom_outputs=request.custom_inputs, + ) diff --git a/apps/complaint-agent/app.yaml b/apps/complaint-agent/app.yaml new file mode 100644 index 0000000..ec46f8c --- /dev/null +++ b/apps/complaint-agent/app.yaml @@ -0,0 +1,12 @@ +command: + - python + - start_server.py +env: + - name: DATABRICKS_CATALOG + 
value: caspersdev + - name: LLM_MODEL + value: databricks-claude-sonnet-4-5 + - name: MLFLOW_TRACKING_URI + value: databricks + - name: MLFLOW_REGISTRY_URI + value: databricks-uc diff --git a/apps/complaint-agent/requirements.txt b/apps/complaint-agent/requirements.txt new file mode 100644 index 0000000..cad5c8c --- /dev/null +++ b/apps/complaint-agent/requirements.txt @@ -0,0 +1,8 @@ +mlflow[databricks]>=3.6.0 +databricks-sdk>=0.81.0 +dspy-ai +unitycatalog-openai[databricks] +openai +pydantic>=2 +fastapi +uvicorn[standard] diff --git a/apps/complaint-agent/start_server.py b/apps/complaint-agent/start_server.py new file mode 100644 index 0000000..5bc2591 --- /dev/null +++ b/apps/complaint-agent/start_server.py @@ -0,0 +1,46 @@ +import inspect +import os + +import agent # noqa: F401 - registers @invoke with MLflow AgentServer +from fastapi import Request +from fastapi.responses import JSONResponse +from mlflow.genai.agent_server import AgentServer, setup_mlflow_git_based_version_tracking +from mlflow.types.responses import ResponsesAgentRequest + + +agent_server = AgentServer("ResponsesAgent") +app = agent_server.app + +setup_mlflow_git_based_version_tracking() + + +@app.post("/responses") +@app.post("/api/responses") +async def responses(request: Request): + from mlflow.genai.agent_server import get_invoke_function + + body = await request.json() + if body.get("stream"): + return JSONResponse( + status_code=400, + content={"error": "stream=true is not supported by this agent app"}, + ) + + invoke_fn = get_invoke_function() + result = invoke_fn(ResponsesAgentRequest(**body)) + if inspect.isawaitable(result): + result = await result + return JSONResponse(content=result.model_dump(mode="json")) + + +def main(): + port = int(os.environ.get("DATABRICKS_APP_PORT", "8000")) + agent_server.run( + app_import_string="start_server:app", + host="0.0.0.0", + port=port, + ) + + +if __name__ == "__main__": + main() diff --git a/apps/refund-agent/agent.py 
b/apps/refund-agent/agent.py new file mode 100644 index 0000000..64b152a --- /dev/null +++ b/apps/refund-agent/agent.py @@ -0,0 +1,235 @@ +import json +import os +import uuid +from typing import Any, Literal, Optional, Sequence, Union + +import mlflow +from databricks.sdk.core import Config +from databricks_langchain import ChatDatabricks +from langchain_core.language_models import LanguageModelLike +from langchain_core.runnables import RunnableConfig, RunnableLambda +from langchain_core.tools import BaseTool, tool +from langgraph.graph import END, StateGraph +try: + from langgraph.graph.graph import CompiledGraph + from langgraph.graph.state import CompiledStateGraph +except ImportError: + CompiledGraph = Any + CompiledStateGraph = Any +from mlflow.genai.agent_server import invoke +from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode +from mlflow.types.responses import ResponsesAgentRequest, ResponsesAgentResponse +from openai import OpenAI +from pydantic import BaseModel, ValidationError +from unitycatalog.ai.core.base import get_uc_function_client + + +mlflow.set_tracking_uri(os.getenv("MLFLOW_TRACKING_URI", "databricks")) +mlflow.set_registry_uri(os.getenv("MLFLOW_REGISTRY_URI", "databricks-uc")) +mlflow.langchain.autolog() + +CATALOG = os.environ["DATABRICKS_CATALOG"] +LLM_MODEL = os.environ["LLM_MODEL"] +PROMPT_URI = f"prompts:/{CATALOG}.prompts.refund_system@production" + +_FALLBACK_PROMPT = """You are RefundGPT, a CX agent responsible for refund decisions on food delivery orders. + + You can call tools to gather the information you need. Start with an `order_id`. + + Instructions: + 1. Call `order_details(order_id)` first to get event history and confirm the id is valid and the order was delivered. + 2. Figure out the delivery duration by calling `get_order_delivery_time(order_id)`. + 3. Extract the location (either directly or from the first event's body). + 4. 
Call `get_location_timings(location)` to get the P50/P75/P99 values. + 5. Compare actual delivery time to those percentiles. + + Refund policy: + + A) SLA-based refund (primary path): + - If the order arrived AFTER the P75 delivery time: recommend a `partial` or `full` refund based on how late. + - If the order arrived BEFORE the P75: no SLA-based refund. + + B) Goodwill credit (only when complaint context is provided in the user message): + The user may include lines such as: + Customer complaint: "" + Complaint category: + Complaint agent suggested credit: $ + When all three are present AND the SLA path returns "none", you MAY ratify the + complaint agent's goodwill credit: + - Set `refund_class` = "partial" + - Set `refund_usd` to the suggested credit amount (capped at $10) + - In `reason`, note that the order was on time per SLA but a goodwill credit + is being issued in response to the customer's complaint (cite the category). + Only ratify when the suggested credit is plausible (>$0 and <=$10) and the + complaint category is non-empty. Otherwise return "none" with an SLA-based reason. + + When NO complaint context is provided, behave exactly as the SLA-based path (A) - + do not invent goodwill credits. + + Output a single-line JSON with these fields: + - `refund_usd` (float), + - `refund_class` ("none" | "partial" | "full"), + - `reason` (short human explanation. If goodwill, say so explicitly.) + + You must return only the JSON. 
No extra text or markdown.""" + + +def _auth_header() -> str: + header = Config().authenticate().get("Authorization", "") + if not header.startswith("Bearer "): + raise RuntimeError("Databricks OAuth bearer token is unavailable") + return header + + +def _validate_gateway_endpoint() -> None: + host = (os.environ.get("DATABRICKS_HOST") or Config().host).rstrip("/") + client = OpenAI( + api_key=_auth_header().removeprefix("Bearer "), + base_url=f"{host}/ai-gateway/mlflow/v1", + timeout=30, + ) + client.chat.completions.create( + model=LLM_MODEL, + messages=[{"role": "user", "content": "Say gateway ok."}], + max_tokens=8, + ) + + +_validate_gateway_endpoint() + +try: + SYSTEM_PROMPT = mlflow.genai.load_prompt(PROMPT_URI).template +except Exception as exc: + print(f"[refund-agent] Could not load {PROMPT_URI}: {type(exc).__name__}: {exc}") + SYSTEM_PROMPT = _FALLBACK_PROMPT + +_uc_client = None + + +def _client(): + global _uc_client + if _uc_client is None: + _uc_client = get_uc_function_client() + return _uc_client + + +class RefundDecision(BaseModel): + refund_usd: float = 0.0 + refund_class: Literal["none", "partial", "full"] = "none" + reason: str = "" + + +@tool +def get_order_details(order_id: str) -> str: + """Get the full event history for an order.""" + return str( + _client() + .execute_function(f"{CATALOG}.ai.get_order_details", {"oid": order_id}) + .value + ) + + +@tool +def get_order_delivery_time(order_id: str) -> str: + """Return creation timestamp, delivered timestamp, and delivery duration.""" + return str( + _client() + .execute_function(f"{CATALOG}.ai.get_order_delivery_time", {"oid": order_id}) + .value + ) + + +@tool +def get_location_timings(location: str) -> str: + """Return P50/P75/P99 delivery time percentiles for a kitchen location.""" + return str( + _client() + .execute_function(f"{CATALOG}.ai.get_location_timings", {"loc": location}) + .value + ) + + +TOOLS = [get_order_details, get_order_delivery_time, get_location_timings] + + +def 
create_tool_calling_agent( + model: LanguageModelLike, + tools: Union[Sequence[BaseTool], ChatAgentToolNode], + system_prompt: Optional[str] = None, +) -> CompiledGraph: + model = model.bind_tools(tools) + + def should_continue(state: ChatAgentState): + messages = state["messages"] + last_message = messages[-1] + return "continue" if last_message.get("tool_calls") else "end" + + if system_prompt: + preprocessor = RunnableLambda( + lambda state: [{"role": "system", "content": system_prompt}] + state["messages"] + ) + else: + preprocessor = RunnableLambda(lambda state: state["messages"]) + model_runnable = preprocessor | model + + def call_model(state: ChatAgentState, config: RunnableConfig): + return {"messages": [model_runnable.invoke(state, config)]} + + workflow = StateGraph(ChatAgentState) + workflow.add_node("agent", RunnableLambda(call_model)) + workflow.add_node("tools", ChatAgentToolNode(tools)) + workflow.set_entry_point("agent") + workflow.add_conditional_edges("agent", should_continue, {"continue": "tools", "end": END}) + workflow.add_edge("tools", "agent") + return workflow.compile() + + +LLM = ChatDatabricks(model=LLM_MODEL, use_ai_gateway=True) +AGENT: CompiledStateGraph = create_tool_calling_agent(LLM, TOOLS, SYSTEM_PROMPT) + + +def _msg_to_dict(msg) -> dict: + if isinstance(msg, dict): + return msg + if hasattr(msg, "model_dump"): + return msg.model_dump() + if hasattr(msg, "dict"): + return msg.dict() + raise TypeError(f"Unsupported message type: {type(msg).__name__}") + + +def _text_output(text: str, item_id: str | None = None) -> dict: + return { + "id": item_id or str(uuid.uuid4()), + "type": "message", + "role": "assistant", + "content": [{"type": "output_text", "text": text}], + } + + +def _run_agent(messages: list[dict]) -> str: + result_messages = [] + for event in AGENT.stream({"messages": messages}, stream_mode="updates"): + for node_data in event.values(): + result_messages.extend(node_data.get("messages", [])) + + for msg in 
reversed(result_messages): + role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None) + content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "") + if role == "assistant" and content: + try: + parsed = RefundDecision.model_validate_json(content) + return parsed.model_dump_json() + except (ValidationError, ValueError, TypeError): + return str(content) + raise RuntimeError("Refund agent produced no assistant message") + + +@invoke() +def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse: + messages = [_msg_to_dict(msg) for msg in request.input] + text = _run_agent(messages) + return ResponsesAgentResponse( + output=[_text_output(text)], + custom_outputs=request.custom_inputs, + ) diff --git a/apps/refund-agent/app.yaml b/apps/refund-agent/app.yaml new file mode 100644 index 0000000..ec46f8c --- /dev/null +++ b/apps/refund-agent/app.yaml @@ -0,0 +1,12 @@ +command: + - python + - start_server.py +env: + - name: DATABRICKS_CATALOG + value: caspersdev + - name: LLM_MODEL + value: databricks-claude-sonnet-4-5 + - name: MLFLOW_TRACKING_URI + value: databricks + - name: MLFLOW_REGISTRY_URI + value: databricks-uc diff --git a/apps/refund-agent/requirements.txt b/apps/refund-agent/requirements.txt new file mode 100644 index 0000000..4c24d1d --- /dev/null +++ b/apps/refund-agent/requirements.txt @@ -0,0 +1,10 @@ +mlflow[databricks]>=3.6.0 +databricks-sdk>=0.81.0 +databricks-langchain +langgraph>=0.3.5,<0.4.0 +langchain-core +unitycatalog-ai[databricks] +openai +pydantic>=2 +fastapi +uvicorn[standard] diff --git a/apps/refund-agent/start_server.py b/apps/refund-agent/start_server.py new file mode 100644 index 0000000..c0dc438 --- /dev/null +++ b/apps/refund-agent/start_server.py @@ -0,0 +1,53 @@ +import inspect +import json +import os + +import agent # noqa: F401 - registers @invoke with MLflow AgentServer +from fastapi import Request +from fastapi.responses import JSONResponse +from 
mlflow.genai.agent_server import AgentServer, setup_mlflow_git_based_version_tracking +from mlflow.types.responses import ResponsesAgentRequest + + +agent_server = AgentServer("ResponsesAgent") +app = agent_server.app + +setup_mlflow_git_based_version_tracking() + + +@app.post("/responses") +@app.post("/api/responses") +async def responses(request: Request): + """Databricks Apps agent-compatible Responses API alias. + + MLflow AgentServer serves /invocations locally. Databricks Apps agent + clients use /responses, so expose the same registered invoke function on + that route too. + """ + from mlflow.genai.agent_server import get_invoke_function + + body = await request.json() + if body.get("stream"): + return JSONResponse( + status_code=400, + content={"error": "stream=true is not supported by this agent app"}, + ) + + invoke_fn = get_invoke_function() + result = invoke_fn(ResponsesAgentRequest(**body)) + if inspect.isawaitable(result): + result = await result + return JSONResponse(content=result.model_dump(mode="json")) + + +def main(): + port = int(os.environ.get("DATABRICKS_APP_PORT", "8000")) + agent_server.run( + app_import_string="start_server:app", + host="0.0.0.0", + port=port, + ) + + +if __name__ == "__main__": + main() diff --git a/databricks.yml b/databricks.yml index 2ce3c9c..d94aa22 100644 --- a/databricks.yml +++ b/databricks.yml @@ -145,8 +145,6 @@ targets: # --params "CATALOG=oleksandra,SKIP_EVAL=false" - name: SKIP_EVAL default: "true" - - name: REFUND_AGENT_ENDPOINT_NAME - default: ${var.catalog}_refund_agent - name: SIMULATOR_SCHEMA default: simulator - name: START_DAY @@ -209,11 +207,10 @@ targets: notebook_path: ${workspace.root_path}/stages/refund_evaluation # Refund_Recommender_Stream calls w.jobs.run_now() on the - # streaming inference job whose UDF queries the Refund agent - # serving endpoint. 
The endpoint must exist and be READY before - # the stream's first batch fires — without the explicit - # Refund_Recommender_Agent edge the two tasks race and the - # stream can hit a missing endpoint. + # streaming inference job whose UDF queries the Refund agent App. + # The App must exist before the stream's first batch fires — + # without the explicit Refund_Recommender_Agent edge the two tasks + # race and the stream can hit a missing app. - task_key: Refund_Recommender_Stream depends_on: - task_key: Spark_Declarative_Pipeline @@ -274,8 +271,6 @@ targets: # --params "CATALOG=oleksandra,SKIP_EVAL=false" - name: SKIP_EVAL default: "true" - - name: COMPLAINT_AGENT_ENDPOINT_NAME - default: ${var.catalog}_complaint_agent - name: COMPLAINT_RATE default: "0.15" - name: SIMULATOR_SCHEMA @@ -595,10 +590,6 @@ targets: - name: PIPELINE_SCHEDULE_MINUTES default: "0" # Refund + complaints - - name: REFUND_AGENT_ENDPOINT_NAME - default: ${var.catalog}_refund_agent - - name: COMPLAINT_AGENT_ENDPOINT_NAME - default: ${var.catalog}_complaint_agent - name: COMPLAINT_RATE default: "0.15" # Operational Dashboard — Knowledge Assistant endpoints are auto-generated @@ -676,11 +667,10 @@ targets: notebook_path: ${workspace.root_path}/stages/refund_evaluation # Refund_Recommender_Stream calls w.jobs.run_now() on the - # streaming inference job whose UDF queries the Refund agent - # serving endpoint. The endpoint must exist and be READY before - # the stream's first batch fires — without the explicit - # Refund_Recommender_Agent edge the two tasks race and the - # stream can hit a missing endpoint. + # streaming inference job whose UDF queries the Refund agent App. + # The App must exist before the stream's first batch fires — + # without the explicit Refund_Recommender_Agent edge the two tasks + # race and the stream can hit a missing app. 
- task_key: Refund_Recommender_Stream depends_on: - task_key: Spark_Declarative_Pipeline diff --git a/jobs/complaint_agent_stream.ipynb b/jobs/complaint_agent_stream.ipynb index 398bc73..903666c 100644 --- a/jobs/complaint_agent_stream.ipynb +++ b/jobs/complaint_agent_stream.ipynb @@ -13,11 +13,22 @@ "cell_type": "code", "metadata": {}, "source": [ - "DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)\n", - "DATABRICKS_HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)\n", - "\n", "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "COMPLAINT_AGENT_ENDPOINT_NAME = dbutils.widgets.get(\"COMPLAINT_AGENT_ENDPOINT_NAME\")" + "\n", + "import os\n", + "import sys\n", + "sys.path.append(os.path.abspath(\"../utils\"))\n", + "from agent_app_client import app_request_context, complaint_agent_app_name, extract_response_text\n", + "\n", + "try:\n", + " COMPLAINT_AGENT_APP_NAME = dbutils.widgets.get(\"COMPLAINT_AGENT_APP_NAME\")\n", + "except Exception:\n", + " COMPLAINT_AGENT_APP_NAME = complaint_agent_app_name(CATALOG)\n", + "\n", + "_AGENT_APP_CONTEXT = app_request_context(app_name=COMPLAINT_AGENT_APP_NAME, dbutils=dbutils)\n", + "COMPLAINT_AGENT_APP_URL = _AGENT_APP_CONTEXT[\"url\"]\n", + "COMPLAINT_AGENT_APP_TOKEN = _AGENT_APP_CONTEXT[\"bearer_token\"]\n", + "print(f\"Complaint agent app: {COMPLAINT_AGENT_APP_NAME} ({COMPLAINT_AGENT_APP_URL})\")\n" ], "execution_count": null, "outputs": [] @@ -26,7 +37,7 @@ "cell_type": "code", "metadata": {}, "source": [ - "%pip install openai" + "%pip install -U -qqqq databricks-sdk requests\n" ], "execution_count": null, "outputs": [] @@ -44,51 +55,77 @@ "from pyspark.sql.functions import udf\n", "from pyspark.sql.window import Window\n", "\n", - "from openai import OpenAI\n", + "import requests\n", "\n", "# Per-call timeout in seconds. 
Without this, a stalled / scale-to-zero /\n", - "# deleted complaint-agent endpoint causes the underlying HTTP socket to hang\n", + "# deleted complaint-agent app causes the underlying HTTP socket to hang\n", "# indefinitely (we observed a single stream run hung for 3h47m on a previous\n", "# test session, blocking the cron queue). 30s is generous for a healthy agent\n", "# call (typical latency is 2-5s) but bounds the worst case.\n", "_CALL_TIMEOUT_S = 30\n", "\n", "# Per-batch inference cap and checkpoint path.\n", - "# Why this exists: previously this stream had no cap — `availableNow=True` would\n", + "# Why this exists: previously this stream had no cap \u2014 `availableNow=True` would\n", "# drain the entire `raw_complaints` backlog through the agent UDF in one go,\n", "# and a few hundred rows of backlog times 5-30s per call easily blew through\n", "# the 10-min `timeout_seconds` task budget. The cron then dropped subsequent\n", "# ticks (queue.enabled=False) and the backlog kept growing forever.\n", "#\n", - "# Sizing the cap (10) against measured endpoint latency:\n", - "# - Observed warm latency on this complaint-agent endpoint is 16-21s/call,\n", + "# Sizing the cap (10) against measured app latency:\n", + "# - Observed warm latency on this complaint-agent app is 16-21s/call,\n", "# not the 2-5s the prior comment assumed. 
With the previous cap of 20\n", "# and a 3-attempt retry loop, a single timed-out call cost up to 90s,\n", "# and we consistently blew the 600s budget at ~615s (every cron tick).\n", "# - With MAX=10 and no retries (see process_complaint below):\n", - "# worst case 10 × 30s = 300s (~5min slack)\n", - "# typical 10 × 18s = 180s (~7min slack)\n", - "# - Refund + support streams use 50 because their endpoints are faster;\n", + "# worst case 10 \u00d7 30s = 300s (~5min slack)\n", + "# typical 10 \u00d7 18s = 180s (~7min slack)\n", + "# - Refund + support streams use 50 because their agents are faster;\n", "# do not blindly copy that cap here without re-measuring.\n", "CHECKPOINT_PATH = f\"/Volumes/{CATALOG}/complaints/checkpoints/complaint_agent_stream\"\n", "# Sized for the 600s task timeout. The complaint agent (DSPy ReAct) measured\n", - "# at ~38s per call — markedly slower than the refunder (~12s) because it does\n", - "# more tool-call iterations per complaint. 5 × 38 ≈ 190s leaves room for\n", - "# cold-start jitter, Delta write, and is_first_run() probes; 10 × 38 = 380s\n", - "# + ~250s overhead was tipping over the 600s wall once the endpoint was\n", + "# at ~38s per call \u2014 markedly slower than the refunder (~12s) because it does\n", + "# more tool-call iterations per complaint. 
5 \u00d7 38 \u2248 190s leaves room for\n", + "# cold-start jitter, Delta write, and is_first_run() probes; 10 \u00d7 38 = 380s\n", + "# + ~250s overhead was tipping over the 600s wall once the agent was\n", "# fully healthy.\n", "MAX_INFERENCES_PER_BATCH = 5\n", "\n", "\n", + "def _extract_agent_text(response):\n", + " if not isinstance(response, dict):\n", + " return str(response)\n", + " direct = response.get(\"output_text\")\n", + " if isinstance(direct, str) and direct:\n", + " return direct\n", + " output = response.get(\"output\") or []\n", + " if isinstance(output, dict):\n", + " output = [output]\n", + " for item in output:\n", + " if not isinstance(item, dict):\n", + " continue\n", + " content = item.get(\"content\")\n", + " if isinstance(content, str) and content:\n", + " return content\n", + " if isinstance(content, dict):\n", + " content = [content]\n", + " if isinstance(content, list):\n", + " for content_item in content:\n", + " if isinstance(content_item, dict):\n", + " text = content_item.get(\"text\")\n", + " if isinstance(text, str) and text:\n", + " return text\n", + " raise ValueError(\"Could not extract final response text\")\n", + "\n", + "\n", "def is_first_run():\n", - " \"\"\"True when no micro-batch has been committed yet — drives the\n", + " \"\"\"True when no micro-batch has been committed yet \u2014 drives the\n", " fake-it-till-up backfill path on cold start.\n", "\n", " Looks at the `commits/` subdirectory specifically. 
Spark Structured\n", " Streaming creates `offsets/`, `commits/`, `metadata`, `sources/` *before*\n", " the first call to `process_batch`, so the older heuristic of\n", " `os.listdir(CHECKPOINT_PATH) == 0` always returned False on the very first\n", - " batch — silently disabling the fast path and forcing the entire historical\n", + " batch \u2014 silently disabling the fast path and forcing the entire historical\n", " backlog through the real-inference codepath in a single micro-batch (which\n", " is what kept blowing past the 10-min task timeout). A non-hidden file\n", " under `commits/` is the unambiguous signal that this is not the first run.\n", @@ -102,7 +139,7 @@ " )\n", "\n", "\n", - "# ── Deterministic fallback responses ──────────────────────────────────────────\n", + "# \u2500\u2500 Deterministic fallback responses \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "# Used both during initial backfill (when the checkpoint is empty, so we want\n", "# to drain the entire historical source instantly) and as the per-batch\n", "# overflow path beyond MAX_INFERENCES_PER_BATCH. 
Schema matches the real\n", @@ -114,35 +151,35 @@ " \"decision\": \"credit\",\n", " \"credit_amount\": 5.0,\n", " \"rationale\": \"Order delivered after P75 threshold for the location.\",\n", - " \"customer_response\": \"Sorry about the late delivery — we've issued a $5.00 credit to your account.\",\n", + " \"customer_response\": \"Sorry about the late delivery \u2014 we've issued a $5.00 credit to your account.\",\n", " },\n", " {\n", " \"complaint_category\": \"missing_item\",\n", " \"decision\": \"credit\",\n", " \"credit_amount\": 7.5,\n", " \"rationale\": \"Missing-item complaints are auto-credited at item value when verified by order line.\",\n", - " \"customer_response\": \"We're sorry an item was missing from your order — we've added a $7.50 credit.\",\n", + " \"customer_response\": \"We're sorry an item was missing from your order \u2014 we've added a $7.50 credit.\",\n", " },\n", " {\n", " \"complaint_category\": \"food_quality\",\n", " \"decision\": \"escalate\",\n", " \"credit_amount\": 0.0,\n", " \"rationale\": \"Food-quality complaints route to a human reviewer with kitchen photos.\",\n", - " \"customer_response\": \"Thanks for the report — our quality team will follow up within 24 hours.\",\n", + " \"customer_response\": \"Thanks for the report \u2014 our quality team will follow up within 24 hours.\",\n", " },\n", " {\n", " \"complaint_category\": \"wrong_order\",\n", " \"decision\": \"credit\",\n", " \"credit_amount\": 10.0,\n", " \"rationale\": \"Wrong-order complaints trigger a same-day credit equal to subtotal cap.\",\n", - " \"customer_response\": \"Apologies for the mixup — a $10.00 credit has been applied to your account.\",\n", + " \"customer_response\": \"Apologies for the mixup \u2014 a $10.00 credit has been applied to your account.\",\n", " },\n", " {\n", " \"complaint_category\": \"other\",\n", " \"decision\": \"no_action\",\n", " \"credit_amount\": 0.0,\n", " \"rationale\": \"Insufficient signal for automated remediation; no 
credit.\",\n", - " \"customer_response\": \"Thanks for reaching out — we'll look into this and follow up if needed.\",\n", + " \"customer_response\": \"Thanks for reaching out \u2014 we'll look into this and follow up if needed.\",\n", " },\n", "]\n", "\n", @@ -156,44 +193,34 @@ "\n", "\n", "def process_complaint(complaint_text: str, order_id: str) -> str:\n", - " \"\"\"Process a complaint through the agent endpoint.\n", - "\n", - " Returns a JSON-encoded response. On any failure returns a default\n", - " \"escalate\" response so the streaming write can still progress \\u2014 we'd rather mark the complaint as needing human\n", - " review than hang the batch.\n", - " \"\"\"\n", - " client = OpenAI(\n", - " api_key=DATABRICKS_TOKEN,\n", - " base_url=f\"{DATABRICKS_HOST}/serving-endpoints\",\n", - " timeout=_CALL_TIMEOUT_S,\n", - " )\n", - "\n", + " \"\"\"Call the complaint agent app for a complaint row.\"\"\"\n", " default_response = json.dumps({\n", " \"order_id\": order_id,\n", " \"complaint_category\": \"other\",\n", " \"decision\": \"escalate\",\n", - " \"credit_amount\": 0.0,\n", - " \"rationale\": \"agent did not return valid JSON\",\n", - " \"customer_response\": \"We're reviewing your complaint and will get back to you shortly.\"\n", + " \"credit_amount\": None,\n", + " \"confidence\": None,\n", + " \"priority\": \"standard\",\n", + " \"rationale\": \"agent unavailable or did not return valid JSON\",\n", " })\n", "\n", - " # No retry: the dominant failure mode here is HTTP timeout, and each\n", - " # retry costs the full _CALL_TIMEOUT_S of wall time. Previous versions\n", - " # did 3 blind retries, making the worst-case per-row cost 90s and\n", - " # consistently blowing the 600s task budget (we observed ~615s every\n", - " # cron tick). 
For an availableNow catch-up stream the next tick will\n", - " # re-process anything we escalate here, so fail-fast keeps each batch\n", - " # bounded; the backlog drains on the next tick instead of stalling now.\n", " try:\n", - " response_obj = client.responses.create(\n", - " model=f\"{COMPLAINT_AGENT_ENDPOINT_NAME}\",\n", - " input=[{\n", - " \"role\": \"user\",\n", - " \"content\": f\"{complaint_text} (Order ID: {order_id})\"\n", - " }],\n", + " http_response = requests.post(\n", + " f\"{COMPLAINT_AGENT_APP_URL}/responses\",\n", + " headers={\n", + " \"Authorization\": f\"Bearer {COMPLAINT_AGENT_APP_TOKEN}\",\n", + " \"Content-Type\": \"application/json\",\n", + " },\n", + " json={\n", + " \"input\": [{\n", + " \"role\": \"user\",\n", + " \"content\": f\"{complaint_text} (Order ID: {order_id})\",\n", + " }]\n", + " },\n", " timeout=_CALL_TIMEOUT_S,\n", " )\n", - " response = response_obj.output[-1].content[0].text\n", + " http_response.raise_for_status()\n", + " response = _extract_agent_text(http_response.json())\n", " json.loads(response)\n", " return response\n", " except Exception:\n", @@ -248,7 +275,7 @@ " print(f\"Processing batch {batch_id}: {row_count} rows, first_run={first_run}\")\n", "\n", " if first_run:\n", - " # Drain the historical backlog with fakes — fast path.\n", + " # Drain the historical backlog with fakes \u2014 fast path.\n", " print(f\" -> First run detected, using fake responses for all {row_count} rows\")\n", " result_df = batch_df.select(\n", " F.col(\"complaint_id\"),\n", @@ -318,7 +345,7 @@ "cell_type": "code", "metadata": {}, "source": [ - "# Enable Change Data Feed for Lakebase sync — idempotent so we don't churn\n", + "# Enable Change Data Feed for Lakebase sync \u2014 idempotent so we don't churn\n", "# the Delta table history with a SET TBLPROPERTIES commit on every cron tick\n", "# (we observed 4+ no-op SET TBLPROPERTIES versions accumulating from this).\n", "table_name = f\"{CATALOG}.complaints.complaint_responses\"\n", @@ 
-374,12 +401,12 @@ " if any(marker in msg for marker in _RECOVERABLE_MARKERS):\n", " first_line = msg.splitlines()[0] if msg else \"\"\n", " print(\n", - " f\"⚠️ Streaming state unrecoverable ({first_line}). \"\n", + " f\"\u26a0\ufe0f Streaming state unrecoverable ({first_line}). \"\n", " f\"Clearing stale checkpoint at {CHECKPOINT_PATH} and restarting \"\n", " f\"from the LATEST version of the source (historical backfill skipped).\"\n", " )\n", " dbutils.fs.rm(CHECKPOINT_PATH, recurse=True)\n", - " print(f\"✅ Cleared {CHECKPOINT_PATH}\")\n", + " print(f\"\u2705 Cleared {CHECKPOINT_PATH}\")\n", " _run_stream(_build_complaint_source(starting_version=\"latest\"))\n", " else:\n", " raise" diff --git a/jobs/refund_recommender_stream.ipynb b/jobs/refund_recommender_stream.ipynb index f41370b..b9baeef 100644 --- a/jobs/refund_recommender_stream.ipynb +++ b/jobs/refund_recommender_stream.ipynb @@ -32,11 +32,22 @@ } }, "source": [ - "DATABRICKS_TOKEN = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)\n", - "DATABRICKS_HOST = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)\n", - "\n", "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "REFUND_AGENT_ENDPOINT_NAME = dbutils.widgets.get(\"REFUND_AGENT_ENDPOINT_NAME\")" + "\n", + "import os\n", + "import sys\n", + "sys.path.append(os.path.abspath(\"../utils\"))\n", + "from agent_app_client import app_request_context, extract_response_text, refund_agent_app_name\n", + "\n", + "try:\n", + " REFUND_AGENT_APP_NAME = dbutils.widgets.get(\"REFUND_AGENT_APP_NAME\")\n", + "except Exception:\n", + " REFUND_AGENT_APP_NAME = refund_agent_app_name(CATALOG)\n", + "\n", + "_AGENT_APP_CONTEXT = app_request_context(app_name=REFUND_AGENT_APP_NAME, dbutils=dbutils)\n", + "REFUND_AGENT_APP_URL = _AGENT_APP_CONTEXT[\"url\"]\n", + "REFUND_AGENT_APP_TOKEN = _AGENT_APP_CONTEXT[\"bearer_token\"]\n", + "print(f\"Refund agent app: {REFUND_AGENT_APP_NAME} 
({REFUND_AGENT_APP_URL})\")\n" ], "execution_count": 0, "outputs": [] @@ -54,7 +65,7 @@ } }, "source": [ - "%pip install openai" + "%pip install -U -qqqq databricks-sdk requests\n" ], "execution_count": 0, "outputs": [] @@ -164,36 +175,62 @@ "import random\n", "import os\n", "\n", - "from openai import OpenAI\n", + "import requests\n", "\n", "# Configuration for inference capping\n", "CHECKPOINT_PATH = f\"/Volumes/{CATALOG}/recommender/checkpoints/refundrecommenderstream\"\n", "# Sized so that even at single-core executor parallelism the batch fits the\n", "# 600s task timeout: measured agent latency is ~12s per call (3 LLM round-trips\n", - "# for tool-calling), so 15 × 12 ≈ 180s leaves comfortable headroom for cold\n", + "# for tool-calling), so 15 \u00d7 12 \u2248 180s leaves comfortable headroom for cold\n", "# starts, Delta write, count(), and is_first_run() probes. Previous value of\n", - "# 50 sized for ~2s/call (fast-fail path while endpoint was UPDATE_FAILED); once\n", - "# the endpoint actually became READY, 50 × 12s = 600s pinned the wall and the\n", + "# 50 sized for ~2s/call (fast-fail path while app was unavailable); once\n", + "# the agent became healthy, 50 \u00d7 12s = 600s pinned the wall and the\n", "# stream began timing out under a healthy agent. Counter-intuitively, fixing\n", - "# the agent broke the stream — so the cap moves with reality.\n", + "# the agent broke the stream \u2014 so the cap moves with reality.\n", "MAX_INFERENCES_PER_BATCH = 15\n", "\n", - "# Per-call timeout (seconds) for the agent endpoint. 
Without this, a wedged\n", - "# endpoint can hang the UDF forever and pin the 10-min task timeout against\n", + "\n", + "def _extract_agent_text(response):\n", + " if not isinstance(response, dict):\n", + " return str(response)\n", + " direct = response.get(\"output_text\")\n", + " if isinstance(direct, str) and direct:\n", + " return direct\n", + " output = response.get(\"output\") or []\n", + " if isinstance(output, dict):\n", + " output = [output]\n", + " for item in output:\n", + " if not isinstance(item, dict):\n", + " continue\n", + " content = item.get(\"content\")\n", + " if isinstance(content, str) and content:\n", + " return content\n", + " if isinstance(content, dict):\n", + " content = [content]\n", + " if isinstance(content, list):\n", + " for content_item in content:\n", + " if isinstance(content_item, dict):\n", + " text = content_item.get(\"text\")\n", + " if isinstance(text, str) and text:\n", + " return text\n", + " raise ValueError(\"Could not extract final response text\")\n", + "\n", + "# Per-call timeout (seconds) for the agent app. Without this, a wedged\n", + "# app call can hang the UDF forever and pin the 10-min task timeout against\n", "# us. 30s bounds worst-case per-row wall time so MAX_INFERENCES_PER_BATCH\n", "# real calls (capped, Spark-parallel) fit well inside the 600s task budget.\n", - "# No retries inside the UDF — see get_chat_completion docstring for why.\n", + "# No retries inside the UDF \u2014 see get_chat_completion docstring for why.\n", "_CALL_TIMEOUT_S = 30\n", "\n", "def is_first_run():\n", - " \"\"\"True when no micro-batch has been committed yet — drives the\n", + " \"\"\"True when no micro-batch has been committed yet \u2014 drives the\n", " fake-it-till-up backfill path on cold start.\n", "\n", " Looks at the `commits/` subdirectory specifically. 
Spark Structured\n", " Streaming creates `offsets/`, `commits/`, `metadata`, `sources/` *before*\n", " the first call to `process_batch`, so the older heuristic of\n", " `os.listdir(CHECKPOINT_PATH) == 0` always returned False on the very first\n", - " batch — silently disabling the fast path and forcing the entire historical\n", + " batch \u2014 silently disabling the fast path and forcing the entire historical\n", " backlog through the real-inference codepath in a single micro-batch. A\n", " non-hidden file under `commits/` is the unambiguous signal that this is\n", " not the first run.\n", @@ -207,25 +244,12 @@ " )\n", "\n", "def get_chat_completion(content: str) -> str:\n", - " \"\"\"Call the refund agent endpoint for real inference.\n", - "\n", - " Single attempt only. Previous versions did 3 blind retries, which made\n", - " the worst-case per-row cost 3 × _CALL_TIMEOUT_S = 90s and blew the 600s\n", - " task budget under endpoint stress (see jobs/complaint_agent_stream.ipynb\n", - " cell defining process_complaint for the same lesson). For an\n", - " availableNow catch-up stream the next 10-min cron tick will re-process\n", - " anything that errored here, so fail-fast keeps each batch bounded.\n", - "\n", - " Failure policy:\n", - " - Endpoint unreachable / timeout → fake response (keep batch flowing).\n", - " - Endpoint returned invalid shape or non-JSON → default_response\n", - " with refund_class='error' so downstream consumers can spot it.\n", + " \"\"\"Call the refund agent app for real inference.\n", + "\n", + " Single attempt only. 
For an availableNow catch-up stream the next\n", + " 10-minute cron tick will re-process anything that errored here, so\n", + " fail-fast keeps each batch bounded.\n", " \"\"\"\n", - " client = OpenAI(\n", - " api_key=DATABRICKS_TOKEN,\n", - " base_url=f\"{DATABRICKS_HOST}/serving-endpoints\",\n", - " timeout=_CALL_TIMEOUT_S,\n", - " )\n", " default_response = json.dumps({\n", " \"refund_usd\": 0.0,\n", " \"refund_class\": \"error\",\n", @@ -233,31 +257,20 @@ " })\n", "\n", " try:\n", - " chat_completion = client.chat.completions.create(\n", - " model=f\"{REFUND_AGENT_ENDPOINT_NAME}\",\n", - " messages=[{\"role\": \"user\", \"content\": content}],\n", + " http_response = requests.post(\n", + " f\"{REFUND_AGENT_APP_URL}/responses\",\n", + " headers={\n", + " \"Authorization\": f\"Bearer {REFUND_AGENT_APP_TOKEN}\",\n", + " \"Content-Type\": \"application/json\",\n", + " },\n", + " json={\"input\": [{\"role\": \"user\", \"content\": content}]},\n", " timeout=_CALL_TIMEOUT_S,\n", " )\n", + " http_response.raise_for_status()\n", + " response = _extract_agent_text(http_response.json())\n", " except Exception:\n", - " # Endpoint unreachable — keep the batch flowing with a fake.\n", " return json.dumps(random.choice(fake_responses))\n", "\n", - " # OpenAI-compatible chat completion: response is in choices[0].message.content,\n", - " # NOT in .messages (older code path raised AttributeError once the agent was up).\n", - " #\n", - " # Defend broadly against the agent returning an unexpected response shape: the\n", - " # OpenAI client best-effort-maps the Databricks model-serving response, and we've\n", - " # observed it materializing `choices=None` (not just empty list) which raises\n", - " # `TypeError: 'NoneType' object is not subscriptable` on `.choices[0]`. 
A narrow\n", - " # `except (AttributeError, IndexError)` is NOT enough — the TypeError escapes the\n", - " # UDF, surfaces as PythonException, and crashes the whole foreachBatch\n", - " # saveAsTable call.\n", - " try:\n", - " choices = getattr(chat_completion, \"choices\", None) or []\n", - " response = choices[0].message.content if choices else \"\"\n", - " except Exception:\n", - " response = \"\"\n", - "\n", " try:\n", " json.loads(response)\n", " return response\n", @@ -324,8 +337,7 @@ " result_df = real_inference_df.union(fake_response_df)\n", "\n", " # Write to table\n", - " result_df.write.mode(\"append\").saveAsTable(f\"{CATALOG}.recommender.refund_recommendations\")\n", - "" + " result_df.write.mode(\"append\").saveAsTable(f\"{CATALOG}.recommender.refund_recommendations\")\n" ], "execution_count": null, "outputs": [] @@ -439,8 +451,8 @@ "# Recover automatically if the source Delta table id has changed since the\n", "# last run (DIFFERENT_DELTA_TABLE_READ_BY_STREAMING_SOURCE). On recovery we\n", "# rebuild the source stream with startingVersion=\"latest\" so we do NOT\n", - "# reprocess the entire recreated table (which is 90 days × 8 locations of\n", - "# canonical data — tens of millions of rows — and was timing out the cluster\n", + "# reprocess the entire recreated table (which is 90 days \u00d7 8 locations of\n", + "# canonical data \u2014 tens of millions of rows \u2014 and was timing out the cluster\n", "# in one giant availableNow micro-batch). Real-time data continues to flow\n", "# from the new version onward.\n", "def _run_stream(events):\n", @@ -473,12 +485,12 @@ " if any(marker in msg for marker in _RECOVERABLE_MARKERS):\n", " first_line = msg.splitlines()[0] if msg else \"\"\n", " print(\n", - " f\"⚠️ Streaming state unrecoverable ({first_line}). \"\n", + " f\"\u26a0\ufe0f Streaming state unrecoverable ({first_line}). 
\"\n", " f\"Clearing stale checkpoint at {CHECKPOINT_PATH} and restarting \"\n", " f\"from the LATEST version of the source (historical backfill skipped).\"\n", " )\n", " dbutils.fs.rm(CHECKPOINT_PATH, recurse=True)\n", - " print(f\"✅ Cleared {CHECKPOINT_PATH}\")\n", + " print(f\"\u2705 Cleared {CHECKPOINT_PATH}\")\n", " _run_stream(_build_delivered_events(starting_version=\"latest\"))\n", " else:\n", " raise" @@ -533,32 +545,6 @@ }, "widgetType": "text" } - }, - "REFUND_AGENT_ENDPOINT_NAME": { - "currentValue": "", - "nuid": "bf162244-8e75-4570-82f4-ef93955c6891", - "typedWidgetInfo": { - "autoCreated": false, - "defaultValue": "", - "label": "", - "name": "REFUND_AGENT_ENDPOINT_NAME", - "options": { - "validationRegex": null, - "widgetDisplayType": "Text" - }, - "parameterDataType": "String" - }, - "widgetInfo": { - "defaultValue": "", - "label": "", - "name": "REFUND_AGENT_ENDPOINT_NAME", - "options": { - "autoCreated": false, - "validationRegex": null, - "widgetType": "text" - }, - "widgetType": "text" - } } } }, diff --git a/stages/complaint_agent.ipynb b/stages/complaint_agent.ipynb index b28c645..157e486 100644 --- a/stages/complaint_agent.ipynb +++ b/stages/complaint_agent.ipynb @@ -3,7 +3,11 @@ { "cell_type": "markdown", "metadata": {}, - "source": "#### complaint agent\n\nBuilds and ships an order-complaint agent using DSPy: author Unity Catalog tools, assemble the DSPy ReAct workflow, evaluate it, and promote the packaged model into production.", + "source": [ + "#### complaint agent\n", + "\n", + "Builds and ships an order-complaint agent using DSPy: author Unity Catalog tools, assemble the DSPy ReAct workflow in the Databricks App source, and deploy the app with MLflow AgentServer.\n" + ], "id": "cell-0" }, { @@ -198,13 +202,11 @@ "metadata": {}, "source": [ "%sql\n", - "-- USE CATALOG is needed in addition to USE SCHEMA + EXECUTE so the serving\n", - "-- endpoint's auto-generated SP can traverse the catalog to reach the UC\n", - "-- functions at 
model-load time. Normally granted by the root data stage\n", - "-- (canonical_data/raw_data), but repeated here so this stage is self-\n", - "-- sufficient if run standalone or against a pre-existing catalog.\n", + "-- USE CATALOG is needed in addition to USE SCHEMA + EXECUTE so the\n", + "-- Databricks App service principal can traverse the catalog to reach the\n", + "-- UC functions at inference time.\n", "GRANT USE CATALOG ON CATALOG ${CATALOG} TO `account users`;\n", - "GRANT USE SCHEMA ON SCHEMA ${CATALOG}.ai TO `account users`;" + "GRANT USE SCHEMA ON SCHEMA ${CATALOG}.ai TO `account users`;\n" ], "execution_count": null, "outputs": [] @@ -214,8 +216,7 @@ "metadata": {}, "source": [ "%sql\n", - "-- Grant EXECUTE to all workspace principals so the serving endpoint SP\n", - "-- (created by agents.deploy) can call these tools at inference time.\n", + "-- Grant EXECUTE so app callers and the agent app SP can call these tools at inference time.\n", "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_order_overview TO `account users`;\n", "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_order_timing TO `account users`;\n", "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_location_timings TO `account users`;" @@ -226,13 +227,23 @@ { "cell_type": "markdown", "metadata": {}, - "source": "#### Model\n\n- Install DSPy, Databricks agent packages, and restart Python for a clean runtime.\n- Capture widget inputs (`CATALOG`, `LLM_MODEL`) and create an MLflow dev experiment for trace logging.\n- Define a templated `%%writefilev` magic that emits files with notebook variable substitution.\n- Materialize `agent.py` containing the DSPy ReAct complaint workflow wired to UC SQL tools and the chosen LLM endpoint.\n- Pull a delivered `order_id` sample and build the MLflow model signature/resources for logging.", + "source": [ + "#### App Agent\n", + "\n", + "- Install orchestration dependencies and restart Python for a clean runtime.\n", + "- Capture widget inputs (`CATALOG`, `LLM_MODEL`) and 
resolve the deterministic Databricks App name.\n", + "- Use `../apps/complaint-agent` as the source of truth for the DSPy ReAct complaint workflow.\n", + "- Treat `LLM_MODEL` as a Unity AI Gateway endpoint name; no custom-agent LLM calls use legacy model-serving invocation routes.\n" + ], "id": "dq5ml4wp6v" }, { "cell_type": "code", "metadata": {}, - "source": "%pip install -U -qqqq typing_extensions dspy-ai mlflow unitycatalog-openai[databricks] openai databricks-sdk databricks-agents pydantic\n%restart_python", + "source": [ + "%pip install -U -qqqq mlflow[databricks] databricks-sdk requests openai\n", + "%restart_python\n" + ], "execution_count": null, "outputs": [], "id": "bd3tu3r2gso" @@ -242,7 +253,15 @@ "metadata": {}, "source": [ "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "LLM_MODEL = dbutils.widgets.get(\"LLM_MODEL\")" + "LLM_MODEL = dbutils.widgets.get(\"LLM_MODEL\")\n", + "\n", + "import sys\n", + "sys.path.append('../utils')\n", + "from agent_app_client import complaint_agent_app_name\n", + "\n", + "APP_NAME = complaint_agent_app_name(CATALOG)\n", + "UC_MODEL_NAME = f\"{CATALOG}.ai.complaint_agent_app\"\n", + "print(f\"Complaint agent app: {APP_NAME}\")\n" ], "execution_count": null, "outputs": [], @@ -261,7 +280,7 @@ "# set_experiment creates the experiment if it doesn't exist, or activates it if it does\n", "dev_experiment = mlflow.set_experiment(dev_experiment_name)\n", "dev_experiment_id = dev_experiment.experiment_id\n", - "print(f\"✅ Using dev experiment: {dev_experiment_name} (ID: {dev_experiment_id})\")\n", + "print(f\"\u2705 Using dev experiment: {dev_experiment_name} (ID: {dev_experiment_id})\")\n", "\n", "# Add experiment to UC state for cleanup\n", "import sys\n", @@ -273,710 +292,350 @@ " \"name\": dev_experiment_name\n", "}\n", "add(CATALOG, \"experiments\", experiment_data)\n", - "print(f\"✅ Added dev experiment to UC state\")" + "print(f\"\u2705 Added dev experiment to UC state\")" ], "execution_count": null, "outputs": [], "id": 
"ktoahik8tl" }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "import re\n", - "import os\n", - "from IPython.core.magic import register_cell_magic\n", - "\n", - "# Records absolute paths of files written by `%%writefilev`, keyed by the\n", - "# filename argument. The import cell below imports `from agent import LLM_MODEL`.\n", - "#\n", - "# Earlier versions wrote to `os.path.abspath(filename)` (i.e. CWD), but CWD\n", - "# on serverless notebooks is usually a `/Workspace/Users/...` path, and the\n", - "# workspace-files-as-Python-modules feature is not reliably wired through\n", - "# on serverless — the file ends up on disk, `os.path.exists()` returns True,\n", - "# but `import agent` still raises `ModuleNotFoundError`. Pinning the write\n", - "# dir to a regular local-disk path (`/local_disk0/tmp/...`, falling back to\n", - "# `/tmp/...`) sidesteps the workspace-files importer entirely.\n", - "_WRITEFILEV_ABS_PATHS = {}\n", - "\n", - "_WRITEFILEV_DIR = \"/local_disk0/tmp/caspers_writefilev\"\n", - "if not os.path.isdir(\"/local_disk0\"):\n", - " _WRITEFILEV_DIR = \"/tmp/caspers_writefilev\"\n", - "os.makedirs(_WRITEFILEV_DIR, exist_ok=True)\n", - "\n", - "@register_cell_magic\n", - "def writefilev(line, cell):\n", - " \"\"\"\n", - " %%writefilev file.py\n", - " Allows {{var}} substitutions while leaving normal {} intact.\n", - "\n", - " Writes to a stable local-disk path (NOT CWD) so subsequent\n", - " `from import ...` always succeeds, even on serverless\n", - " where CWD is a /Workspace path.\n", - " \"\"\"\n", - " filename = line.strip()\n", - "\n", - " def replacer(match):\n", - " expr = match.group(1)\n", - " return str(eval(expr, globals(), locals()))\n", - "\n", - " content = re.sub(r\"\\{\\{(.*?)\\}\\}\", replacer, cell)\n", + "#### Production Experiment\n", "\n", - " abs_path = os.path.join(_WRITEFILEV_DIR, filename)\n", - " with open(abs_path, \"w\") as f:\n", - " f.write(content)\n", - " 
_WRITEFILEV_ABS_PATHS[filename] = abs_path\n", - " print(f\"Wrote file with substitutions: {abs_path}\")" - ], - "execution_count": null, - "outputs": [], - "id": "bruu85upqq5" + "Create the production MLflow experiment used by the Databricks App runtime for traces." + ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "%%writefilev agent.py\n", - "import warnings\n", - "from typing import Optional, Literal\n", - "from uuid import uuid4\n", - "from pydantic import BaseModel, Field, field_validator, ValidationError\n", - "\n", - "warnings.filterwarnings(\"ignore\", message=\".*Ignoring the default notebook Spark session.*\")\n", - "\n", - "import dspy\n", "import mlflow\n", - "from unitycatalog.ai.core.base import get_uc_function_client\n", - "from mlflow.pyfunc import ResponsesAgent\n", - "from mlflow.types.responses import (\n", - " ResponsesAgentRequest,\n", - " ResponsesAgentResponse,\n", - ")\n", - "\n", - "# Enable DSPy autologging for automatic trace capture\n", - "mlflow.dspy.autolog(log_traces=True)\n", - "\n", - "LLM_MODEL = \"{{LLM_MODEL}}\"\n", - "CATALOG = \"{{CATALOG}}\"\n", - "\n", - "# Configure DSPy with Databricks LM.\n", - "# - num_retries makes litellm back off and retry every individual LLM call\n", - "# on REQUEST_LIMIT_EXCEEDED, so a single 429 inside DSPy's ReAct loop\n", - "# doesn't fail the whole agent invocation.\n", - "# - cache=False is required for the retry loop in ComplaintTriageModule.forward\n", - "# to actually retry — with cache=True every retry returns the same cached\n", - "# response and the same ValidationError fires every time.\n", - "lm = dspy.LM(f'databricks/{LLM_MODEL}', max_tokens=2000, num_retries=20, cache=False)\n", - "dspy.configure(lm=lm)\n", - "\n", - "# Initialize UC function client\n", - "uc_client = get_uc_function_client()\n", - "\n", - "\n", - "class ComplaintResponse(BaseModel):\n", - " \"\"\"Structured output for complaint triage decisions.\"\"\"\n", - " 
order_id: str\n", - " complaint_category: Literal[\"delivery_delay\", \"missing_items\", \"food_quality\", \"service_issue\", \"billing\", \"other\"] = Field(\n", - " description=\"Exactly ONE primary complaint category\"\n", - " )\n", - " decision: Literal[\"suggest_credit\", \"escalate\"]\n", - " credit_amount: Optional[float] = None\n", - " confidence: Optional[Literal[\"high\", \"medium\", \"low\"]] = None\n", - " priority: Optional[Literal[\"standard\", \"urgent\"]] = None\n", - " rationale: str\n", - " \n", - " @field_validator('complaint_category', mode='before')\n", - " @classmethod\n", - " def parse_category(cls, v):\n", - " \"\"\"Extract first valid category if multiple provided.\"\"\"\n", - " if not isinstance(v, str):\n", - " return v\n", - " \n", - " valid_categories = [\"delivery_delay\", \"missing_items\", \"food_quality\", \"service_issue\", \"billing\", \"other\"]\n", - " v_lower = v.lower().strip()\n", - " \n", - " # Exact match\n", - " if v_lower in valid_categories:\n", - " return v_lower\n", - " \n", - " # Find first valid category in string\n", - " for cat in valid_categories:\n", - " if cat in v_lower:\n", - " return cat\n", - " \n", - " return \"other\"\n", - " \n", - " @field_validator('confidence', mode='before')\n", - " @classmethod\n", - " def parse_confidence(cls, v):\n", - " \"\"\"Ensure valid confidence value.\"\"\"\n", - " if v is None or (isinstance(v, str) and v.lower() == \"null\"):\n", - " return None\n", - " if isinstance(v, str):\n", - " v_lower = v.lower().strip()\n", - " if v_lower in [\"high\", \"medium\", \"low\"]:\n", - " return v_lower\n", - " return \"medium\"\n", - " return v\n", - " \n", - " @field_validator('priority', mode='before')\n", - " @classmethod\n", - " def parse_priority(cls, v):\n", - " \"\"\"Ensure valid priority value.\"\"\"\n", - " if v is None or (isinstance(v, str) and v.lower() == \"null\"):\n", - " return None\n", - " if isinstance(v, str):\n", - " v_lower = v.lower().strip()\n", - " if v_lower in 
[\"standard\", \"urgent\"]:\n", - " return v_lower\n", - " return \"standard\"\n", - " return v\n", - "\n", - "\n", - "class ComplaintTriage(dspy.Signature):\n", - " \"\"\"Analyze customer complaints for Casper's Kitchens and recommend triage actions.\n", - " \n", - " Process:\n", - " 1. Extract order_id from complaint\n", - " 2. Use get_order_overview(order_id) for order details and items\n", - " 3. Use get_order_timing(order_id) for delivery timing\n", - " 4. For delays, use get_location_timings(location) for percentile benchmarks\n", - " 5. Make data-backed decision\n", - " \n", - " Decision Framework:\n", - " \n", - " SUGGEST_CREDIT (with credit_amount and confidence):\n", - " - Delivery delays: Compare actual delivery time to location percentiles\n", - " * P99: Suggest 25% of order total (high confidence)\n", - " - Missing items: Use actual item prices from order data when available\n", - " * Verify claimed item exists in order (affects confidence)\n", - " * Use real costs from order data, or estimate $8-12 per item if unavailable\n", - " - Food quality: 20-40% of order total based on severity\n", - " * Minor issues (slightly cold, minor preparation issue): 20% (medium confidence)\n", - " * Major issues (completely inedible, wrong preparation, health concern): 40% (high confidence)\n", - " * Vague complaints (\"bad\", \"gross\"): escalate instead\n", - " \n", - " ESCALATE (with priority):\n", - " - priority=\"standard\": Vague complaints, missing data, billing issues, service complaints\n", - " - priority=\"urgent\": Legal threats, health/safety concerns, suspected fraud, abusive language\n", - " \n", - " Output Requirements:\n", - " - For suggest_credit: credit_amount is REQUIRED and must be a number (can be 0.0 if no credit warranted), confidence is REQUIRED, priority must be null\n", - " - For escalate: priority is REQUIRED, credit_amount and confidence must be null\n", - " - complaint_category: Choose EXACTLY ONE category (the primary one)\n", - " - 
Rationale must cite specific evidence (delivery times, percentiles, item verification, order total)\n", - " - Rationale should be detailed but under 150 words\n", - " - Round credit amounts to nearest $0.50\n", - " - Confidence: high (strong data), medium (reasonable inference), low (weak/contradictory)\n", - " \"\"\"\n", - " \n", - " complaint: str = dspy.InputField(desc=\"Customer complaint text\")\n", - " order_id: str = dspy.OutputField(desc=\"Extracted order ID\")\n", - " complaint_category: str = dspy.OutputField(desc=\"EXACTLY ONE category: delivery_delay, missing_items, food_quality, service_issue, billing, or other\")\n", - " decision: str = dspy.OutputField(desc=\"EXACTLY ONE: suggest_credit or escalate\")\n", - " credit_amount: str = dspy.OutputField(desc=\"If suggest_credit: MUST be a number (e.g., 0.0, 10.5). If escalate: null\")\n", - " confidence: str = dspy.OutputField(desc=\"If suggest_credit: EXACTLY ONE of high, medium, low. If escalate: null\")\n", - " priority: str = dspy.OutputField(desc=\"If escalate: EXACTLY ONE of standard or urgent. 
If suggest_credit: null\")\n", - " rationale: str = dspy.OutputField(desc=\"Data-focused justification citing specific evidence\")\n", - "\n", - "\n", - "# Unity Catalog tool wrappers\n", - "def get_order_overview(order_id: str) -> str:\n", - " \"\"\"Get order details including items, location, and customer info.\"\"\"\n", - " result = uc_client.execute_function(\n", - " f\"{CATALOG}.ai.get_order_overview\",\n", - " {\"oid\": order_id}\n", - " )\n", - " return str(result.value)\n", - "\n", "\n", - "def get_order_timing(order_id: str) -> str:\n", - " \"\"\"Get timing information for a specific order.\"\"\"\n", - " result = uc_client.execute_function(\n", - " f\"{CATALOG}.ai.get_order_timing\",\n", - " {\"oid\": order_id}\n", - " )\n", - " return str(result.value)\n", - "\n", - "\n", - "def get_location_timings(location: str) -> str:\n", - " \"\"\"Get delivery time percentiles for a specific location.\"\"\"\n", - " result = uc_client.execute_function(\n", - " f\"{CATALOG}.ai.get_location_timings\",\n", - " {\"loc\": location}\n", - " )\n", - " return str(result.value)\n", - "\n", - "\n", - "class ComplaintTriageModule(dspy.Module):\n", - " \"\"\"DSPy module for complaint triage with tool calling.\"\"\"\n", - " \n", - " def __init__(self):\n", - " super().__init__()\n", - " self.react = dspy.ReAct(\n", - " signature=ComplaintTriage,\n", - " tools=[get_order_overview, get_order_timing, get_location_timings],\n", - " max_iters=10\n", - " )\n", - " \n", - " def forward(self, complaint: str, max_retries: int = 2) -> ComplaintResponse:\n", - " \"\"\"Process complaint and return structured triage decision with retry on validation failure.\"\"\"\n", - " \n", - " for attempt in range(max_retries + 1):\n", - " try:\n", - " result = self.react(complaint=complaint)\n", - " \n", - " # Parse credit_amount. 
$0 credit is a valid suggest_credit\n", - " # response (per the eval scorer: \"credit_amount of $0 is valid\n", - " # when rationale indicates no issue\"), so when the LLM emits\n", - " # decision='suggest_credit' but no parseable credit_amount we\n", - " # treat it as $0 rather than failing the whole call. Other\n", - " # malformed fields (Literal types) still raise via Pydantic\n", - " # and trigger the retry loop below.\n", - " credit_amount = None\n", - " if result.credit_amount and result.credit_amount.lower() != \"null\":\n", - " try:\n", - " credit_amount = float(result.credit_amount)\n", - " except (ValueError, TypeError):\n", - " credit_amount = None\n", - "\n", - " if result.decision == \"suggest_credit\" and credit_amount is None:\n", - " credit_amount = 0.0\n", - " \n", - " # Construct Pydantic model - field validators run here\n", - " return ComplaintResponse(\n", - " order_id=result.order_id,\n", - " complaint_category=result.complaint_category,\n", - " decision=result.decision,\n", - " credit_amount=credit_amount,\n", - " confidence=result.confidence,\n", - " priority=result.priority,\n", - " rationale=result.rationale\n", - " )\n", - " \n", - " except (ValidationError, ValueError) as e:\n", - " if attempt < max_retries:\n", - " # Retry - DSPy will regenerate with potentially different output\n", - " continue\n", - " else:\n", - " # Final attempt failed - re-raise\n", - " raise\n", - "\n", - "\n", - "class DSPyComplaintAgent(ResponsesAgent):\n", - " \"\"\"ResponsesAgent wrapper for DSPy complaint triage module.\"\"\"\n", - " \n", - " def __init__(self):\n", - " self.module = ComplaintTriageModule()\n", - " \n", - " def predict(self, request: ResponsesAgentRequest) -> ResponsesAgentResponse:\n", - " \"\"\"Process complaint request and return structured response.\"\"\"\n", - " complaint = None\n", - " for msg in request.input:\n", - " msg_dict = msg.model_dump() if hasattr(msg, \"model_dump\") else msg\n", - " if msg_dict.get(\"role\") == \"user\":\n", 
- " complaint = msg_dict.get(\"content\", \"\")\n", - " break\n", - " \n", - " if not complaint:\n", - " raise ValueError(\"No user message found in request\")\n", - " \n", - " result = self.module(complaint=complaint)\n", - " \n", - " return ResponsesAgentResponse(\n", - " output=[\n", - " self.create_text_output_item(\n", - " text=result.model_dump_json(),\n", - " id=str(uuid4())\n", - " )\n", - " ],\n", - " custom_outputs=request.custom_inputs\n", - " )\n", + "prod_experiment_name = f\"/Shared/{CATALOG}_complaint_agent_prod\"\n", + "prod_experiment = mlflow.set_experiment(prod_experiment_name)\n", + "prod_experiment_id = prod_experiment.experiment_id\n", + "print(f\"Using prod experiment: {prod_experiment_name} (ID: {prod_experiment_id})\")\n", "\n", + "import sys\n", + "sys.path.append('../utils')\n", + "from uc_state import add\n", "\n", - "# Initialize agent\n", - "AGENT = DSPyComplaintAgent()\n", - "mlflow.models.set_model(AGENT)" - ], - "execution_count": null, - "outputs": [], - "id": "w2c70l4ca8" + "add(CATALOG, \"experiments\", {\n", + " \"experiment_id\": prod_experiment_id,\n", + " \"name\": prod_experiment_name,\n", + "})\n", + "print(\"Added prod experiment to UC state\")\n" + ] }, { - "cell_type": "code", + "cell_type": "markdown", "metadata": {}, "source": [ - "# Get an actual order_id for the input example. In continuous-pipeline\n", - "# mode the lakeflow stage finishes before any data has flowed through, so\n", - "# poll up to 6 minutes for the first delivered event (same pattern as\n", - "# stages/refunder_agent.ipynb).\n", - "import time\n", - "\n", - "sample_order_id = None\n", - "for attempt in range(12):\n", - " rows = spark.sql(f\"\"\"\n", - " SELECT order_id\n", - " FROM {CATALOG}.lakeflow.all_events\n", - " WHERE event_type='delivered'\n", - " LIMIT 1\n", - " \"\"\").collect()\n", - " if rows:\n", - " sample_order_id = rows[0]['order_id']\n", - " break\n", - " print(f\"No delivered events yet (attempt {attempt+1}/12). 
Waiting 30s for pipeline data...\")\n", - " time.sleep(30)\n", + "#### Prompt Registry\n", "\n", - "if not sample_order_id:\n", - " raise RuntimeError(\n", - " f\"No delivered events found in {CATALOG}.lakeflow.all_events after 6 minutes. \"\n", - " \"Ensure the Canonical_Data and Lakeflow pipeline stages completed and processed data.\"\n", - " )" - ], - "execution_count": null, - "outputs": [], - "id": "dd6gjrp4mx6" + "Seed the prompt registry from the Databricks App source so prompt governance stays with the deployed app code." + ] }, { "cell_type": "code", - "metadata": {}, - "source": [ - "assert sample_order_id is not None\n", - "print(sample_order_id)" - ], "execution_count": null, - "outputs": [], - "id": "ih3tt5qeb5" - }, - { - "cell_type": "code", "metadata": {}, + "outputs": [], "source": [ - "import mlflow\n", - "import sys\n", "import os\n", + "import re\n", + "import sys\n", "\n", - "# Use the absolute path captured by `%%writefilev` (see cell 13). This is\n", - "# stable across CWD drift / kernel restarts; the older `sys.path.append(os.getcwd())`\n", - "# pattern was failing with `ModuleNotFoundError: No module named 'agent'` when\n", - "# CWD changed between the writefile cell and this one.\n", - "_agent_py_path = _WRITEFILEV_ABS_PATHS.get(\"agent.py\")\n", - "if _agent_py_path and os.path.exists(_agent_py_path):\n", - " _agent_dir = os.path.dirname(_agent_py_path)\n", - " print(f\"Importing agent from: {_agent_py_path}\")\n", - "else:\n", - " # Defensive fallback: try the same locations the old code did, plus\n", - " # /databricks/driver which is the typical CWD on classic clusters,\n", - " # and the new pinned local-disk dir used by `%%writefilev`.\n", - " _candidate_dirs = [\n", - " os.getcwd(),\n", - " \"/databricks/driver\",\n", - " \"/local_disk0/tmp/caspers_writefilev\",\n", - " \"/local_disk0/tmp\",\n", - " \"/tmp/caspers_writefilev\",\n", - " \"/tmp\",\n", - " ]\n", - " try:\n", - " _nb_path = 
dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()\n", - " _candidate_dirs.append(os.path.dirname(_nb_path))\n", - " except Exception:\n", - " pass\n", - " _agent_dir = next((d for d in _candidate_dirs if os.path.exists(os.path.join(d, \"agent.py\"))), None)\n", - " if _agent_dir is None:\n", - " raise FileNotFoundError(\n", - " f\"agent.py not found. _WRITEFILEV_ABS_PATHS={_WRITEFILEV_ABS_PATHS}, \"\n", - " f\"CWD={os.getcwd()}, candidates tried: {_candidate_dirs}\"\n", - " )\n", - " _agent_py_path = os.path.join(_agent_dir, \"agent.py\")\n", - " print(f\"Importing agent from fallback dir: {_agent_py_path}\")\n", - "\n", - "if _agent_dir not in sys.path:\n", - " sys.path.insert(0, _agent_dir)\n", + "sys.path.append('../utils')\n", + "from prompt_registry import seed_prompt_history\n", "\n", - "# Invalidate any cached `agent` module that may be left over from a prior\n", - "# attempt on the same warehouse (e.g. after a transient failure).\n", - "sys.modules.pop(\"agent\", None)\n", + "_agent_py_path = os.path.abspath(\"../apps/complaint-agent/agent.py\")\n", + "with open(_agent_py_path) as f:\n", + " _agent_py = f.read()\n", "\n", - "from agent import LLM_MODEL\n", - "from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint\n", - "from pkg_resources import get_distribution\n", + "_match = re.search(\n", + " r'class ComplaintTriage\\(dspy\\.Signature\\):\\s*\"\"\"(.*?)\"\"\"',\n", + " _agent_py,\n", + " re.DOTALL,\n", + ")\n", + "if not _match:\n", + " raise RuntimeError(\"Could not extract ComplaintTriage docstring from complaint app source\")\n", + "_current_complaint_prompt = _match.group(1).strip()\n", "\n", - "resources = [DatabricksServingEndpoint(endpoint_name=LLM_MODEL)]\n", - "# Add UC function resources\n", - "uc_tool_names = [\n", - " f\"{CATALOG}.ai.get_order_overview\",\n", - " f\"{CATALOG}.ai.get_order_timing\",\n", - " f\"{CATALOG}.ai.get_location_timings\",\n", - "]\n", - "for func_name in 
uc_tool_names:\n", - " resources.append(DatabricksFunction(function_name=func_name))\n", + "_COMPLAINT_V1 = (\n", + " \"Analyze customer complaints and recommend a triage action.\\n\\n\"\n", + " \"For each complaint:\\n\"\n", + " \"- Extract the order_id\\n\"\n", + " \"- Choose a decision: suggest_credit OR escalate\\n\"\n", + " \"- Provide a rationale\\n\\n\"\n", + " \"Output: order_id, decision, rationale.\"\n", + ")\n", + "_COMPLAINT_V2 = (\n", + " \"Analyze customer complaints for Casper's Kitchens and recommend triage actions.\\n\\n\"\n", + " \"Process:\\n\"\n", + " \"1. Extract order_id from complaint\\n\"\n", + " \"2. Use get_order_overview(order_id) for order details and items\\n\"\n", + " \"3. Make a decision: suggest_credit (with credit_amount) or escalate (with priority)\\n\\n\"\n", + " \"Decision Framework:\\n\\n\"\n", + " \"SUGGEST_CREDIT:\\n\"\n", + " \"- Delivery delays: 10-20% of order total\\n\"\n", + " \"- Missing items: estimate $10 per item\\n\"\n", + " \"- Food quality: 20-40% of order total based on severity\\n\\n\"\n", + " \"ESCALATE:\\n\"\n", + " \"- priority=\\\"standard\\\": vague complaints, billing issues, service complaints\\n\"\n", + " \"- priority=\\\"urgent\\\": legal threats, health/safety concerns, abusive language\\n\\n\"\n", + " \"Output: order_id, complaint_category, decision, credit_amount, confidence, priority, rationale.\"\n", + ")\n", "\n", - "input_example = {\n", - " \"input\": [\n", - " {\n", - " \"role\": \"user\",\n", - " \"content\": f\"My order was really late! 
Order ID: {sample_order_id}\"\n", - " }\n", - " ]\n", + "_common_tags = {\n", + " \"agent\": \"complaint\",\n", + " \"stage\": \"complaint_agent\",\n", + " \"app_name\": APP_NAME,\n", + " \"uc_model\": UC_MODEL_NAME,\n", + " \"consumed_via\": \"DSPy Signature docstring in Databricks App source\",\n", "}\n", "\n", - "# Create custom conda environment with mlflow explicitly specified\n", - "conda_env = {\n", - " \"channels\": [\"conda-forge\"],\n", - " \"dependencies\": [\n", - " \"python=3.11\",\n", - " \"pip\",\n", + "seed_prompt_history(\n", + " spark=spark,\n", + " catalog=CATALOG,\n", + " name=\"complaint_system\",\n", + " historical=[\n", " {\n", - " \"pip\": [\n", - " \"mlflow==3.6\",\n", - " f\"typing_extensions=={get_distribution('typing_extensions').version}\",\n", - " f\"dspy-ai=={get_distribution('dspy-ai').version}\",\n", - " f\"unitycatalog-openai[databricks]=={get_distribution('unitycatalog-openai').version}\",\n", - " f\"pydantic=={get_distribution('pydantic').version}\",\n", - " ]\n", - " }\n", + " \"template\": _COMPLAINT_V1,\n", + " \"commit_message\": \"v1: minimal triage, decision + rationale only, no tool use (demo history seed)\",\n", + " \"tags\": _common_tags,\n", + " },\n", + " {\n", + " \"template\": _COMPLAINT_V2,\n", + " \"commit_message\": \"v2: added decision framework with categories + flat credit heuristics (demo history seed)\",\n", + " \"tags\": _common_tags,\n", + " },\n", " ],\n", - " \"name\": \"mlflow-env\"\n", - "}\n", - "\n", - "with mlflow.start_run():\n", - " logged_agent_info = mlflow.pyfunc.log_model(\n", - " name=\"complaint_agent\",\n", - " python_model=_agent_py_path,\n", - " input_example=input_example,\n", - " resources=resources,\n", - " conda_env=conda_env,\n", - " )\n", - "\n", - "mlflow.set_active_model(model_id = logged_agent_info.model_id)" - ], - "execution_count": null, - "outputs": [], - "id": "23o4h7j6amzh" + " current={\n", + " \"template\": _current_complaint_prompt,\n", + " \"commit_message\": \"v3 
(production): percentile-based credit calc + timing tools, deployed as Databricks App\",\n", + " \"tags\": {**_common_tags, \"deployment_kind\": \"databricks_app\"},\n", + " },\n", + ")\n" + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "#### Log the Agent to Unity Catalog\n\n- Point MLflow at the Unity Catalog registry and name the artifact `${CATALOG}.ai.complaint_agent`.\n- Register the run-produced model so versioned deployments can be promoted through UC stages.", - "id": "ttse3kj3pcj" - }, - { - "cell_type": "code", - "metadata": {}, "source": [ - "from databricks.sdk import WorkspaceClient\n", - "from databricks.sdk.service.serving import EndpointStateReady\n", + "#### Deploy Agent App\n", "\n", - "mlflow.set_registry_uri(\"databricks-uc\")\n", - "\n", - "UC_MODEL_NAME = f\"{CATALOG}.ai.complaint_agent\"\n", - "endpoint_name = dbutils.widgets.get(\"COMPLAINT_AGENT_ENDPOINT_NAME\")\n", - "\n", - "\n", - "def _endpoint_already_serving(name: str, uc_model_name: str) -> bool:\n", - " \"\"\"Return True iff a serving endpoint is READY and already serving uc_model_name.\n", - "\n", - " Used to short-circuit register_model + agents.deploy on re-runs of this\n", - " stage when the endpoint from a previous run is still healthy — saves ~15\n", - " minutes of cold container build + serving provisioning. 
To force a fresh\n", - " deploy after editing agent code, delete the endpoint and rerun the stage.\n", - " \"\"\"\n", - " try:\n", - " ep = WorkspaceClient().serving_endpoints.get(name)\n", - " except Exception:\n", - " return False\n", - " if not ep.state or ep.state.ready != EndpointStateReady.READY:\n", - " return False\n", - " cfg = getattr(ep, \"config\", None) or getattr(ep, \"pending_config\", None)\n", - " if not cfg:\n", - " return False\n", - " served = []\n", - " for se in (getattr(cfg, \"served_entities\", None) or []):\n", - " n = getattr(se, \"entity_name\", None)\n", - " if n:\n", - " served.append(n)\n", - " for sm in (getattr(cfg, \"served_models\", None) or []):\n", - " n = getattr(sm, \"model_name\", None)\n", - " if n:\n", - " served.append(n)\n", - " return uc_model_name in served\n", - "\n", - "\n", - "_reuse_endpoint = _endpoint_already_serving(endpoint_name, UC_MODEL_NAME)\n", - "\n", - "if _reuse_endpoint:\n", - " print(\n", - " f\"\\u267b\\ufe0f Endpoint {endpoint_name} is already READY and serving {UC_MODEL_NAME}; \"\n", - " f\"skipping register_model + agents.deploy (saves ~15 min). 
\"\n", - " f\"Delete the endpoint to force a fresh deploy.\"\n", - " )\n", - " uc_registered_model_info = None\n", - "else:\n", - " # register the model to UC\n", - " uc_registered_model_info = mlflow.register_model(\n", - " model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME\n", - " )" - ], - "execution_count": null, - "outputs": [], - "id": "y8lfco9zzn" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "#### Deploy the Agent to Model Serving\n\n- Create a production MLflow experiment for live trace capture.\n- Use `agents.deploy` to create/update the Databricks Model Serving endpoint backed by the UC model version.\n- Wait until the serving endpoint reports READY before continuing to downstream steps.\n- Pass the prod experiment ID via environment variables so inference traces are logged automatically.", - "id": "me3m6ovfqkd" + "Create/update the Databricks App, grant the app service principal UC and Gateway access, deploy source, and register the app in uc_state." 
+ ] }, { "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "import mlflow\n", - "\n", - "# Create prod experiment for production inference traces\n", - "# Use shared path for job compatibility and visibility\n", - "prod_experiment_name = f\"/Shared/{CATALOG}_complaint_agent_prod\"\n", - "\n", - "# set_experiment creates the experiment if it doesn't exist, or activates it if it does\n", - "prod_experiment = mlflow.set_experiment(prod_experiment_name)\n", - "prod_experiment_id = prod_experiment.experiment_id\n", - "print(f\"✅ Using prod experiment: {prod_experiment_name} (ID: {prod_experiment_id})\")\n", - "\n", - "# Add experiment to UC state for cleanup\n", + "import os\n", "import sys\n", + "import time\n", + "\n", "sys.path.append('../utils')\n", + "from agent_app_client import gateway_chat_probe\n", "from uc_state import add\n", "\n", - "experiment_data = {\n", - " \"experiment_id\": prod_experiment_id,\n", - " \"name\": prod_experiment_name\n", - "}\n", - "add(CATALOG, \"experiments\", experiment_data)\n", - "print(f\"✅ Added prod experiment to UC state\")" - ], - "execution_count": null, - "outputs": [], - "id": "qjnjjztbbna" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "from datetime import timedelta\n", + "from databricks.sdk import WorkspaceClient\n", + "from databricks.sdk.service import catalog as catalog_svc\n", + "from databricks.sdk.service.apps import App, AppDeployment\n", + "from databricks.sdk.service.serving import (\n", + " ServingEndpointAccessControlRequest,\n", + " ServingEndpointPermissionLevel,\n", + ")\n", "\n", - "from databricks import agents\n", + "w = WorkspaceClient()\n", + "source_code_path = os.path.abspath(\"../apps/complaint-agent\")\n", + "print(f\"App name: {APP_NAME}\")\n", + "print(f\"App source: {source_code_path}\")\n", + "\n", + "gateway_chat_probe(llm_model=LLM_MODEL, w=w, dbutils=dbutils)\n", + "print(f\"Verified {LLM_MODEL} is queryable through Unity AI 
Gateway\")\n", + "\n", + "app_yaml_path = os.path.join(source_code_path, \"app.yaml\")\n", + "app_yaml_contents = f\"\"\"command:\n", + " - python\n", + " - start_server.py\n", + "env:\n", + " - name: DATABRICKS_CATALOG\n", + " value: '{CATALOG}'\n", + " - name: LLM_MODEL\n", + " value: '{LLM_MODEL}'\n", + " - name: MLFLOW_EXPERIMENT_ID\n", + " value: '{prod_experiment_id}'\n", + " - name: MLFLOW_TRACKING_URI\n", + " value: 'databricks'\n", + " - name: MLFLOW_REGISTRY_URI\n", + " value: 'databricks-uc'\n", + "\"\"\"\n", + "with open(app_yaml_path, \"w\") as f:\n", + " f.write(app_yaml_contents)\n", + "print(f\"Wrote app runtime config: {app_yaml_path}\")\n", + "\n", + "app_def = App(\n", + " name=APP_NAME,\n", + " description=\"Casper's complaint triage agent served by MLflow AgentServer on Databricks Apps.\",\n", + " default_source_code_path=source_code_path,\n", + ")\n", + "try:\n", + " w.apps.get(APP_NAME)\n", + " print(f\"App {APP_NAME} exists, updating...\")\n", + " w.apps.update(APP_NAME, app_def)\n", + "except Exception:\n", + " print(f\"Creating app {APP_NAME}...\")\n", + " w.apps.create(app_def)\n", + "\n", + "\n", + "def _app_state(a):\n", + " cs = getattr(a, \"compute_status\", None)\n", + " s = getattr(cs, \"state\", None) if cs is not None else None\n", + " if s is None:\n", + " s = getattr(a, \"state\", None)\n", + " return getattr(s, \"value\", str(s)) if s is not None else \"\"\n", + "\n", + "\n", + "deadline = time.time() + 30 * 60\n", + "while True:\n", + " current = w.apps.get(APP_NAME)\n", + " state = _app_state(current)\n", + " print(f\"App {APP_NAME} state: {state}\")\n", + " if state in (\"ACTIVE\", \"RUNNING\", \"READY\"):\n", + " app_status = current\n", + " break\n", + " if state in (\"ERROR\", \"FAILED\"):\n", + " raise RuntimeError(f\"App {APP_NAME} entered failure state: {state}\")\n", + " if time.time() > deadline:\n", + " raise TimeoutError(f\"App {APP_NAME} not ready after 30 minutes (last state: {state})\")\n", + " 
time.sleep(15)\n", + "\n", + "app_sp_id = (\n", + " getattr(app_status, \"service_principal_client_id\", None)\n", + " or (app_status.as_dict() if hasattr(app_status, \"as_dict\") else {}).get(\"service_principal_client_id\")\n", + ")\n", + "app_uc_principal = (\n", + " getattr(app_status, \"id\", None)\n", + " or app_sp_id\n", + " or (app_status.as_dict() if hasattr(app_status, \"as_dict\") else {}).get(\"id\")\n", + ")\n", + "assert app_sp_id, \"Could not determine app service principal client ID\"\n", + "assert app_uc_principal, \"Could not determine app UC principal\"\n", + "print(f\"App SP ID: {app_sp_id}\")\n", + "\n", + "for full_name, securable_type, privilege in [\n", + " (f\"{CATALOG}\", \"CATALOG\", catalog_svc.Privilege.USE_CATALOG),\n", + " (f\"{CATALOG}.ai\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", + " (f\"{CATALOG}.prompts\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", + " (f\"{CATALOG}.ai.get_order_overview\", \"FUNCTION\", catalog_svc.Privilege.EXECUTE),\n", + " (f\"{CATALOG}.ai.get_order_timing\", \"FUNCTION\", catalog_svc.Privilege.EXECUTE),\n", + " (f\"{CATALOG}.ai.get_location_timings\", \"FUNCTION\", catalog_svc.Privilege.EXECUTE),\n", + "]:\n", + " try:\n", + " w.grants.update(\n", + " full_name=full_name,\n", + " securable_type=securable_type,\n", + " changes=[\n", + " catalog_svc.PermissionsChange(\n", + " add=[privilege],\n", + " principal=app_uc_principal,\n", + " )\n", + " ],\n", + " )\n", + " print(f\"Granted {privilege} on {securable_type} {full_name}\")\n", + " except Exception as e:\n", + " print(f\"Could not grant {privilege} on {full_name} to {app_uc_principal}: {e}\")\n", + "\n", + "# LLM_MODEL names a Unity AI Gateway-backed endpoint. 
We only use the serving\n", + "# endpoint permissions API to grant CAN_QUERY on that Gateway endpoint; no LLM\n", + "# request is routed through legacy model-serving invocation routes.\n", + "llm_endpoint = None\n", + "try:\n", + " llm_endpoint = w.serving_endpoints.get(LLM_MODEL)\n", + "except Exception:\n", + " matches = [ep for ep in w.serving_endpoints.list() if ep.name == LLM_MODEL]\n", + " if matches:\n", + " llm_endpoint = matches[0]\n", + "if llm_endpoint is None or not getattr(llm_endpoint, \"id\", None):\n", + " raise RuntimeError(f\"Could not resolve Gateway endpoint {LLM_MODEL} for permission grant\")\n", + "\n", + "w.serving_endpoints.update_permissions(\n", + " serving_endpoint_id=llm_endpoint.id,\n", + " access_control_list=[\n", + " ServingEndpointAccessControlRequest(\n", + " service_principal_name=app_sp_id,\n", + " permission_level=ServingEndpointPermissionLevel.CAN_QUERY,\n", + " )\n", + " ],\n", + ")\n", + "print(f\"Granted CAN_QUERY on Gateway endpoint {LLM_MODEL} to app SP {app_sp_id}\")\n", "\n", - "if _reuse_endpoint:\n", - " deployment_info = None\n", - " print(f\"\\u2705 Endpoint {endpoint_name} is READY (reused from previous deploy)\")\n", - "else:\n", - " deployment_info = agents.deploy(\n", - " model_name=UC_MODEL_NAME,\n", - " model_version=uc_registered_model_info.version,\n", - " scale_to_zero=False,\n", - " endpoint_name=endpoint_name,\n", - " environment_vars={\"MLFLOW_EXPERIMENT_ID\": str(prod_experiment_id)},\n", + "try:\n", + " w.api_client.do(\n", + " \"PATCH\",\n", + " f\"/api/2.0/permissions/apps/{APP_NAME}\",\n", + " body={\"access_control_list\": [{\"group_name\": \"account users\", \"permission_level\": \"CAN_USE\"}]},\n", " )\n", + " print(\"Granted CAN_USE on app to account users for notebook/job smoke tests\")\n", + "except Exception as e:\n", + " print(f\"Could not grant account users CAN_USE on app {APP_NAME}: {e}\")\n", + "\n", + "add(CATALOG, \"apps\", {\n", + " \"name\": APP_NAME,\n", + " \"url\": 
getattr(app_status, \"url\", \"\"),\n", + " \"service_principal_client_id\": app_sp_id,\n", + " \"oauth2_app_client_id\": getattr(app_status, \"oauth2_app_client_id\", \"\"),\n", + " \"agent\": 'complaint',\n", + "})\n", + "print(\"Registered app in UC state\")\n", + "\n", + "deployment = w.apps.deploy(\n", + " app_name=app_status.name,\n", + " app_deployment=AppDeployment(source_code_path=source_code_path),\n", + ")\n", "\n", - " workspace = WorkspaceClient()\n", - " ready_endpoint = workspace.serving_endpoints.wait_get_serving_endpoint_not_updating(\n", - " name=endpoint_name,\n", - " timeout=timedelta(minutes=30),\n", - " )\n", "\n", - " if ready_endpoint.state.ready != EndpointStateReady.READY:\n", - " raise RuntimeError(\n", - " f\"Endpoint {endpoint_name} is {ready_endpoint.state.ready} after deployment; retry or investigate.\"\n", - " )\n", + "def _deploy_state(d):\n", + " st = getattr(d, \"status\", None)\n", + " s = getattr(st, \"state\", None) if st is not None else None\n", + " return getattr(s, \"value\", str(s)) if s is not None else \"\"\n", "\n", - " print(f\"\\u2705 Endpoint {endpoint_name} is READY\")\n", - "" - ], - "execution_count": null, - "outputs": [], - "id": "exckljo2zx4" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "# === Grant UC perms to the endpoint's runtime System Service Principal ===\n", - "#\n", - "# Model serving endpoints run inference as a workspace-level SCIM SP whose\n", - "# displayName is \"System Service Principal\". These SPs are NOT members of\n", - "# `account users`, so grants made above to `account users` do NOT apply to\n", - "# them — every fresh endpoint would fail on its first tool call with\n", - "# PERMISSION_DENIED (\"USE CATALOG\" / \"USE SCHEMA\" / \"EXECUTE\") until\n", - "# permissions were granted manually.\n", - "#\n", - "# See utils/agent_runtime_grants.py for the full rationale. Grants are\n", - "# idempotent. 
Skipped on the reuse path (no fresh SP to grant to).\n", - "if deployment_info is not None:\n", - " import sys\n", - " sys.path.append('../utils')\n", - " from agent_runtime_grants import grant_agent_runtime_perms\n", "\n", - " # Pass endpoint_name so the helper also grants to the endpoint\n", - " # creator — that's the actual runtime identity in EMBEDDED_CREDENTIALS\n", - " # mode workspaces (where 'System Service Principal' isn't created and\n", - " # `account users` may be empty at the workspace level).\n", - " grant_agent_runtime_perms(\n", - " spark,\n", - " CATALOG,\n", - " workspace_client=workspace,\n", - " endpoint_name=endpoint_name,\n", - " )\n", - "else:\n", - " print(\"♻ Endpoint reused — skipping runtime SP grants (already applied on first deploy).\")\n", - "" - ], - "execution_count": null, - "outputs": [], - "id": "00fc632e" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "print(deployment_info)" - ], - "execution_count": null, - "outputs": [], - "id": "i1syfkwcivl" + "deadline = time.time() + 30 * 60\n", + "while True:\n", + " current_dep = w.apps.get_deployment(app_name=app_status.name, deployment_id=deployment.deployment_id)\n", + " state = _deploy_state(current_dep)\n", + " print(f\"Deployment state: {state}\")\n", + " if state == \"SUCCEEDED\":\n", + " deployment_status = current_dep\n", + " break\n", + " if state in (\"FAILED\", \"STOPPED\"):\n", + " raise RuntimeError(f\"Deployment failed for {app_status.name}: state={state}\")\n", + " if time.time() > deadline:\n", + " raise TimeoutError(f\"Deployment for {app_status.name} not ready after 30 minutes (last state: {state})\")\n", + " time.sleep(10)\n", + "\n", + "print(f\"Complaint agent app deployed: {getattr(app_status, 'url', '')}\")\n", + "display(deployment_status)\n" + ] }, { "cell_type": "markdown", "metadata": {}, - "source": "#### Record Model in State\n\n- Store the deployment metadata with `uc_state.add` to facilitate cleanup in the future.", - "id": "4i9yj8vjs2" - 
}, - { - "cell_type": "code", - "metadata": {}, "source": [ - "# Also add to UC-state — but only when we actually deployed a new endpoint.\n", - "# On the reuse path (_reuse_endpoint) the endpoint was already registered by a\n", - "# previous run, so deployment_info is None and re-adding would either crash\n", - "# uc_state.add() or persist a null row that cleanup can't resolve. Mirror the\n", - "# guard used in refunder_agent for symmetry.\n", - "import sys\n", - "sys.path.append('../utils')\n", - "from uc_state import add\n", + "#### Production Monitoring\n", "\n", - "if deployment_info is not None:\n", - " add(dbutils.widgets.get(\"CATALOG\"), \"endpoints\", deployment_info)\n", - "else:\n", - " print(\"✅ Endpoint already tracked in uc_state from a previous deploy; skipping add.\")" + "- Register MLflow guideline scorers to monitor decision quality and refund reasoning on production traffic.\n", + "- Enable 100% sampling to flag decision drift or policy regressions without impacting performance."
], - "execution_count": null, - "outputs": [], - "id": "ihnhnv5plw" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": "#### Production Monitoring\n\n- Register MLflow guideline scorers to monitor decision quality and refund reasoning on production traffic.\n- Enable 10% sampling to flag decision drift or policy regressions without impacting performance.", "id": "bhjj7zuan9g" }, { @@ -1026,7 +685,7 @@ " sample_rate=1.0,\n", ")\n", "\n", - "# Domain — complaint-specific decision and refund-reason guidelines.\n", + "# Domain \u2014 complaint-specific decision and refund-reason guidelines.\n", "decision_quality_monitor = _register_scorer(\n", " Guidelines(\n", " name=\"decision_quality_prod\",\n", @@ -1049,137 +708,11 @@ " name=f\"{UC_MODEL_NAME}_refund_reason\",\n", ")\n", "\n", - "print(\"\\u2705 Production monitoring enabled — 5 scorers active at 100% sampling\")" + "print(\"\\u2705 Production monitoring enabled \u2014 5 scorers active at 100% sampling\")" ], "execution_count": null, "outputs": [], "id": "v1j0a0y21ct" - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Prompt Registry\n", - "\n", - "Register the `ComplaintTriage` signature docstring (the actual instructions the\n", - "DSPy ReAct agent operates on) under `{CATALOG}.prompts.complaint_system` so it\n", - "lives alongside the deployed model in Unity Catalog.\n", - "\n", - "**Important**: unlike the Refund agent (which calls `mlflow.genai.load_prompt`\n", - "at endpoint startup and so can hot-swap), the Complaint agent's prompt is\n", - "baked into the DSPy Signature class at model-log time. Registry entries here\n", - "are **audit / governance**, not runtime hot-swap. Changing the production\n", - "prompt requires a redeploy that bakes the new docstring into a new model\n", - "version.\n", - "\n", - "The registry is seeded with **three versions** on first deploy\n", - "(v1 → v2 → v3) so the MLflow Prompt Registry UI shows a meaningful version\n", - "history. 
The seeded versions are tagged `is_demo_seed=\"true\"` to make it\n", - "clear they're synthetic baselines, not a real engineering changelog." - ], - "id": "60e5b88f" - }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "import re\n", - "import sys\n", - "\n", - "sys.path.append('../utils')\n", - "from prompt_registry import seed_prompt_history\n", - "\n", - "# Use the absolute path captured by %%writefilev (see cell 14) so this works\n", - "# even when CWD has drifted between cells (notebooks on serverless / shared\n", - "# clusters don't have a stable CWD across %pip installs and kernel restarts).\n", - "_agent_py_path = _WRITEFILEV_ABS_PATHS.get(\"agent.py\", \"agent.py\")\n", - "with open(_agent_py_path) as f:\n", - " _agent_py = f.read()\n", - "\n", - "# Extract the ComplaintTriage signature docstring — that's the prompt content\n", - "# DSPy ReAct uses to drive the agent. NOTE: DSPy adds its own ReAct scaffold\n", - "# around this at runtime; the registry entry captures only the domain\n", - "# instructions, which is the part a human would actually iterate on.\n", - "_match = re.search(\n", - " r'class ComplaintTriage\\(dspy\\.Signature\\):\\s*\"\"\"(.*?)\"\"\"',\n", - " _agent_py,\n", - " re.DOTALL,\n", - ")\n", - "if not _match:\n", - " raise RuntimeError(\n", - " \"Could not extract ComplaintTriage docstring from agent.py. \"\n", - " \"If the Signature class was renamed, update this regex.\"\n", - " )\n", - "_current_complaint_prompt = _match.group(1).strip()\n", - "\n", - "# Two earlier versions, seeded on first deploy so the Prompt Registry UI shows\n", - "# v1 → v2 → v3 history. 
These are demo seeds, NOT a real engineering changelog\n", - "# — seed_prompt_history tags each with is_demo_seed=\"true\" so anyone auditing\n", - "# the registry can tell synthetic baselines apart from real iterations.\n", - "_COMPLAINT_V1 = (\n", - " \"Analyze customer complaints and recommend a triage action.\\n\\n\"\n", - " \"For each complaint:\\n\"\n", - " \"- Extract the order_id\\n\"\n", - " \"- Choose a decision: suggest_credit OR escalate\\n\"\n", - " \"- Provide a rationale\\n\\n\"\n", - " \"Output: order_id, decision, rationale.\"\n", - ")\n", - "_COMPLAINT_V2 = (\n", - " \"Analyze customer complaints for Casper's Kitchens and recommend triage actions.\\n\\n\"\n", - " \"Process:\\n\"\n", - " \"1. Extract order_id from complaint\\n\"\n", - " \"2. Use get_order_overview(order_id) for order details and items\\n\"\n", - " \"3. Make a decision: suggest_credit (with credit_amount) or escalate (with priority)\\n\\n\"\n", - " \"Decision Framework:\\n\\n\"\n", - " \"SUGGEST_CREDIT:\\n\"\n", - " \"- Delivery delays: 10-20% of order total\\n\"\n", - " \"- Missing items: estimate $10 per item\\n\"\n", - " \"- Food quality: 20-40% of order total based on severity\\n\\n\"\n", - " \"ESCALATE:\\n\"\n", - " \"- priority=\\\"standard\\\": vague complaints, billing issues, service complaints\\n\"\n", - " \"- priority=\\\"urgent\\\": legal threats, health/safety concerns, abusive language\\n\\n\"\n", - " \"Output: order_id, complaint_category, decision, credit_amount, confidence, priority, rationale.\"\n", - ")\n", - "\n", - "_uc_version = (\n", - " uc_registered_model_info.version\n", - " if \"uc_registered_model_info\" in dir() and uc_registered_model_info is not None\n", - " else \"reused-endpoint\"\n", - ")\n", - "\n", - "_common_tags = {\n", - " \"agent\": \"complaint\",\n", - " \"stage\": \"complaint_agent\",\n", - " \"uc_model\": UC_MODEL_NAME,\n", - " \"consumed_via\": \"baked into DSPy Signature docstring at model-log time (no runtime hot-swap)\",\n", - "}\n", - 
"\n", - "seed_prompt_history(\n", - " spark=spark,\n", - " catalog=CATALOG,\n", - " name=\"complaint_system\",\n", - " historical=[\n", - " {\n", - " \"template\": _COMPLAINT_V1,\n", - " \"commit_message\": \"v1: minimal triage — decision + rationale only, no tool use (demo history seed)\",\n", - " \"tags\": _common_tags,\n", - " },\n", - " {\n", - " \"template\": _COMPLAINT_V2,\n", - " \"commit_message\": \"v2: added decision framework with categories + flat credit heuristics (demo history seed)\",\n", - " \"tags\": _common_tags,\n", - " },\n", - " ],\n", - " current={\n", - " \"template\": _current_complaint_prompt,\n", - " \"commit_message\": f\"v3 (production): percentile-based credit calc + get_order_timing/get_location_timings tools — UC model version {_uc_version}\",\n", - " \"tags\": {**_common_tags, \"uc_model_version\": str(_uc_version)},\n", - " },\n", - ")" - ], - "execution_count": null, - "outputs": [], - "id": "9f25ed8c" } ], "metadata": { diff --git a/stages/complaint_agent_stream.ipynb b/stages/complaint_agent_stream.ipynb index 8db44d8..c074f3d 100644 --- a/stages/complaint_agent_stream.ipynb +++ b/stages/complaint_agent_stream.ipynb @@ -25,10 +25,15 @@ "from databricks.sdk import WorkspaceClient\n", "import databricks.sdk.service.jobs as j\n", "import os\n", + "import sys\n", + "\n", + "sys.path.append('../utils')\n", + "from agent_app_client import complaint_agent_app_name\n", "\n", "w = WorkspaceClient()\n", "\n", "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", + "COMPLAINT_AGENT_APP_NAME = complaint_agent_app_name(CATALOG)\n", "\n", "notebook_abs_path = os.path.abspath(\"../jobs/complaint_agent_stream\")\n", "notebook_dbx_path = notebook_abs_path.replace(\n", @@ -38,11 +43,8 @@ "\n", "job_name = f\"Complaint Agent Stream ({CATALOG})\"\n", "\n", - "# timeout_seconds=600 (10 min, matches cron) so a single hung UDF call against\n", - "# a stalled complaint-agent endpoint cannot block the whole queue forever.\n", - "# Without this, an orphaned run 
from an earlier test session can hold the queue\n", - "# slot indefinitely while every subsequent cron tick piles up as QUEUED (we\n", - "# observed 23 queued runs behind one 3h47m-hung run before adding this).\n", + "# timeout_seconds=600 (10 min, matches cron) so a single hung complaint-agent\n", + "# app call cannot block the whole queue forever.\n", "task_def = [\n", " j.Task(\n", " task_key=\"complaint_agent_stream\",\n", @@ -51,7 +53,7 @@ " notebook_path=notebook_dbx_path,\n", " base_parameters={\n", " \"CATALOG\": CATALOG,\n", - " \"COMPLAINT_AGENT_ENDPOINT_NAME\": dbutils.widgets.get(\"COMPLAINT_AGENT_ENDPOINT_NAME\"),\n", + " \"COMPLAINT_AGENT_APP_NAME\": COMPLAINT_AGENT_APP_NAME,\n", " },\n", " )\n", " )\n", @@ -63,9 +65,9 @@ ")\n", "\n", "# queue.enabled=False: drop cron triggers if a previous run is still active\n", - "# instead of stacking them up. For an availableNow catch-up stream, dropping\n", + "# instead of stacking them up. For an availableNow catch-up stream, dropping\n", "# is correct: the NEXT tick will pick up whatever rows the previous run didn't\n", - "# get to. 
Default (enabled=True) creates an unbounded backlog on any hang.\n", + "# get to.\n", "queue_def = j.QueueSettings(enabled=False)\n", "\n", "existing = [jb for jb in w.jobs.list(name=job_name) if jb.settings.name == job_name]\n", @@ -74,7 +76,7 @@ " w.jobs.reset(job_id=job_id, new_settings=j.JobSettings(\n", " name=job_name, tasks=task_def, schedule=schedule_def, queue=queue_def,\n", " ))\n", - " print(f\"\\u267b\\ufe0f Updated existing job_id={job_id}\")\n", + " print(f\"Updated existing job_id={job_id}\")\n", "else:\n", " job = w.jobs.create(name=job_name, tasks=task_def, schedule=schedule_def, queue=queue_def)\n", " job_id = job.job_id\n", @@ -82,10 +84,10 @@ " sys.path.append('../utils')\n", " from uc_state import add\n", " add(CATALOG, \"jobs\", job)\n", - " print(f\"\\u2705 Created job_id={job_id}\")\n", + " print(f\"Created job_id={job_id}\")\n", "\n", "w.jobs.run_now(job_id=job_id)\n", - "print(f\"\\U0001f680 Started run of {job_name}\")" + "print(f\"Started run of {job_name} against app {COMPLAINT_AGENT_APP_NAME}\")\n" ], "execution_count": null, "outputs": [] diff --git a/stages/complaint_evaluation.ipynb b/stages/complaint_evaluation.ipynb index 1216cea..a8bdec8 100644 --- a/stages/complaint_evaluation.ipynb +++ b/stages/complaint_evaluation.ipynb @@ -4,29 +4,25 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Complaint Agent — Evaluation\n", + "### Complaint Agent \u2014 Evaluation\n", "\n", - "Standalone evaluation task for the deployed Complaint Agent. Runs after\n", - "`Complaint_Triage_Agent` in the pipeline so the endpoint is up before\n", - "we hit it.\n", + "Standalone evaluation task for the deployed Complaint Agent App. Runs after\n", + "`Complaint_Triage_Agent` in the pipeline so the app is ready before we hit it.\n", "\n", - "- **Gated by `SKIP_EVAL`** (default `\"true\"`) — flip to `\"false\"` to actually\n", + "- **Gated by `SKIP_EVAL`** (default `\"true\"`) \u2014 flip to `\"false\"` to actually\n", " run the evaluation. 
The default skip is a deliberate rate-limit safeguard:\n", - " the complaint agent uses DSPy ReAct, which fires 5–10 LLM calls per\n", - " invocation, so even a small eval set can blow past the workspace QPS\n", - " limit on the shared judge endpoint.\n", - "- Calls the deployed endpoint via `mlflow.deployments.get_deploy_client`\n", - " rather than importing the agent module — so this notebook is fully\n", - " self-contained and tests the actual production endpoint shape.\n", - "- Eval results land in `/Shared/{CATALOG}_complaint_agent_dev`." + " the complaint agent uses DSPy ReAct, which fires 5-10 LLM calls per invocation.\n", + "- Calls the deployed Databricks App via its MLflow AgentServer `/responses`\n", + " contract rather than importing the agent module.\n", + "- Eval results land in `/Shared/{CATALOG}_complaint_agent_dev`.\n" ] }, { "cell_type": "code", "metadata": {}, "source": [ - "%pip install -U -qqqq mlflow-skinny[databricks] databricks-sdk\n", - "dbutils.library.restartPython()" + "%pip install -U -qqqq mlflow-skinny[databricks] databricks-sdk requests\n", + "dbutils.library.restartPython()\n" ], "execution_count": null, "outputs": [] @@ -36,7 +32,13 @@ "metadata": {}, "source": [ "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "ENDPOINT_NAME = dbutils.widgets.get(\"COMPLAINT_AGENT_ENDPOINT_NAME\")\n", + "\n", + "import os\n", + "import sys\n", + "sys.path.append(os.path.abspath(\"../utils\"))\n", + "from agent_app_client import complaint_agent_app_name\n", + "\n", + "APP_NAME = complaint_agent_app_name(CATALOG)\n", "\n", "try:\n", " SKIP_EVAL = dbutils.widgets.get(\"SKIP_EVAL\").strip().lower() == \"true\"\n", @@ -49,10 +51,10 @@ "mlflow.set_experiment(DEV_EXPERIMENT)\n", "\n", "print(f\"Catalog: {CATALOG}\")\n", - "print(f\"Endpoint: {ENDPOINT_NAME}\")\n", + "print(f\"App: {APP_NAME}\")\n", "print(f\"SKIP_EVAL: {SKIP_EVAL}\")\n", "print(f\"Dev experiment: {DEV_EXPERIMENT}\")\n", - "print(f\"MLflow version: {mlflow.__version__}\")" + 
"print(f\"MLflow version: {mlflow.__version__}\")\n" ], "execution_count": null, "outputs": [] @@ -63,7 +65,7 @@ "source": [ "if SKIP_EVAL:\n", " print(\n", - " f\"⏭ SKIP_EVAL=true — skipping mlflow.genai.evaluate to avoid the ~50 \"\n", + " f\"\u23ed SKIP_EVAL=true \u2014 skipping mlflow.genai.evaluate to avoid the ~50 \"\n", " f\"LM-call eval burst (DSPy ReAct fires 5-10 LLM calls per agent invocation). \"\n", " f\"Pass --params \\\"SKIP_EVAL=false\\\" to actually run the evaluation.\"\n", " )\n", @@ -84,25 +86,34 @@ "POLL_INTERVAL_S = 15\n", "MAX_POLLS = 40\n", "\n", - "print(f\"Polling endpoint readiness ({MAX_POLLS * POLL_INTERVAL_S // 60} min max)…\")\n", + "print(f\"Polling app readiness ({MAX_POLLS * POLL_INTERVAL_S // 60} min max)...\")\n", + "\n", + "\n", + "def _app_state(a):\n", + " cs = getattr(a, \"compute_status\", None)\n", + " s = getattr(cs, \"state\", None) if cs is not None else None\n", + " if s is None:\n", + " s = getattr(a, \"state\", None)\n", + " return getattr(s, \"value\", str(s)) if s is not None else \"\"\n", + "\n", "\n", "for attempt in range(1, MAX_POLLS + 1):\n", " try:\n", - " ep = w.serving_endpoints.get(ENDPOINT_NAME)\n", - " ready = str(getattr(ep.state, \"ready\", \"\")).upper() if ep.state else \"\"\n", - " cfg_update = str(getattr(ep.state, \"config_update\", \"\")).upper() if ep.state else \"\"\n", - " if \"READY\" in ready:\n", - " print(f\" ✅ Endpoint READY (config_update={cfg_update or 'n/a'})\")\n", + " app = w.apps.get(APP_NAME)\n", + " state = _app_state(app)\n", + " url = getattr(app, \"url\", \"\")\n", + " if state in (\"ACTIVE\", \"RUNNING\", \"READY\") and url:\n", + " print(f\" App ready: state={state}, url={url}\")\n", " break\n", - " print(f\" ⏳ [{attempt}/{MAX_POLLS}] ready={ready}, config_update={cfg_update}\")\n", + " print(f\" [{attempt}/{MAX_POLLS}] state={state or 'unknown'}, url={url or 'pending'}\")\n", " except Exception as e:\n", - " print(f\" ⚠️ poll error: {type(e).__name__}: {e}\")\n", + " 
print(f\" poll error: {type(e).__name__}: {e}\")\n", " time.sleep(POLL_INTERVAL_S)\n", "else:\n", " raise RuntimeError(\n", - " f\"Endpoint {ENDPOINT_NAME} did not become READY within \"\n", + " f\"App {APP_NAME} did not become ready within \"\n", " f\"{MAX_POLLS * POLL_INTERVAL_S // 60} minutes.\"\n", - " )" + " )\n" ], "execution_count": null, "outputs": [] @@ -113,19 +124,15 @@ "source": [ "import os\n", "import random\n", - "import json\n", - "from mlflow.deployments import get_deploy_client\n", + "import time\n", + "\n", + "sys.path.append(os.path.abspath(\"../utils\"))\n", + "from agent_app_client import call_agent_app_text\n", "\n", "os.environ[\"MLFLOW_GENAI_EVAL_MAX_WORKERS\"] = \"1\"\n", "os.environ[\"MLFLOW_GENAI_EVAL_MAX_SCORER_WORKERS\"] = \"1\"\n", - "# Bump client-side HTTP timeouts. Defaults (120s) trip up agent endpoints and\n", - "# the LLM judge calls inside scorers (Safety / RelevanceToQuery / Guidelines).\n", - "os.environ.setdefault(\"MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT\", \"600\")\n", - "os.environ.setdefault(\"MLFLOW_DEPLOYMENT_PREDICT_TOTAL_TIMEOUT\", \"900\")\n", "os.environ.setdefault(\"MLFLOW_HTTP_REQUEST_TIMEOUT\", \"600\")\n", "\n", - "_deploy_client = get_deploy_client(\"databricks\")\n", - "\n", "\n", "def _is_rate_limit(exc: BaseException) -> bool:\n", " msg = str(exc)\n", @@ -136,60 +143,33 @@ " )\n", "\n", "\n", - "def _extract_text(response) -> str:\n", - " \"\"\"Pull the assistant text out of a ResponsesAgent endpoint response.\"\"\"\n", - " if isinstance(response, str):\n", - " try:\n", - " response = json.loads(response)\n", - " except (json.JSONDecodeError, ValueError):\n", - " return response\n", - " if not isinstance(response, dict):\n", - " return str(response)\n", - "\n", - " preds = response.get(\"predictions\")\n", - " if preds is not None:\n", - " if isinstance(preds, list) and preds:\n", - " return _extract_text(preds[0])\n", - " if isinstance(preds, (dict, str)):\n", - " return _extract_text(preds)\n", - "\n", - " 
output = response.get(\"output\", [])\n", - " if isinstance(output, list):\n", - " for item in output:\n", - " if not isinstance(item, dict):\n", - " continue\n", - " content = item.get(\"content\")\n", - " if isinstance(content, str):\n", - " return content\n", - " if isinstance(content, list) and content:\n", - " first = content[0]\n", - " if isinstance(first, dict):\n", - " return first.get(\"text\") or first.get(\"value\") or str(first)\n", - " return str(first)\n", - " if isinstance(output, str):\n", - " return output\n", - " return str(response)\n", + "def _input_messages(value):\n", + " if isinstance(value, dict):\n", + " return value.get(\"input\") or value.get(\"messages\") or [value]\n", + " return value\n", "\n", "\n", "def predict_fn(input):\n", - " \"\"\"Call the deployed Complaint Agent endpoint with retry-on-rate-limit.\"\"\"\n", + " \"\"\"Call the deployed Complaint Agent App with retry-on-rate-limit.\"\"\"\n", " max_attempts = 6\n", + " input_messages = _input_messages(input)\n", " for attempt in range(max_attempts):\n", " try:\n", - " response = _deploy_client.predict(\n", - " endpoint=ENDPOINT_NAME,\n", - " inputs={\"input\": input},\n", + " return call_agent_app_text(\n", + " app_name=APP_NAME,\n", + " input_messages=input_messages,\n", + " dbutils=dbutils,\n", + " timeout=600,\n", " )\n", - " return _extract_text(response)\n", " except Exception as exc:\n", " if not _is_rate_limit(exc) or attempt == max_attempts - 1:\n", " raise\n", " backoff = min(60, 2 ** attempt) + random.uniform(0, 1)\n", - " print(f\" ⚠️ rate-limited (attempt {attempt + 1}/{max_attempts}), sleeping {backoff:.1f}s\")\n", + " print(f\" rate-limited (attempt {attempt + 1}/{max_attempts}), sleeping {backoff:.1f}s\")\n", " time.sleep(backoff)\n", "\n", "\n", - "print(f\"✅ predict_fn defined — will call endpoint {ENDPOINT_NAME}\")" + "print(f\"predict_fn defined \u2014 will call app {APP_NAME}\")\n" ], "execution_count": null, "outputs": [] @@ -257,7 +237,7 @@ "except 
ImportError:\n", " RelevanceToQuery = None\n", " _has_relevance = False\n", - " print(\"⚠\\ufe0f RelevanceToQuery not available in this MLflow version — skipping it from the eval mix\")\n", + " print(\"\u26a0\\ufe0f RelevanceToQuery not available in this MLflow version \u2014 skipping it from the eval mix\")\n", "\n", "evidence_reasoning = Guidelines(\n", " name=\"evidence_reasoning\",\n", @@ -295,7 +275,7 @@ "if _has_relevance:\n", " _DATASET_SCORERS.append(RelevanceToQuery())\n", "\n", - "print(f\"✅ Scorers defined ({len(_DATASET_SCORERS)}): \" + \", \".join(s.name if hasattr(s, 'name') else type(s).__name__ for s in _DATASET_SCORERS))" + "print(f\"\u2705 Scorers defined ({len(_DATASET_SCORERS)}): \" + \", \".join(s.name if hasattr(s, 'name') else type(s).__name__ for s in _DATASET_SCORERS))" ], "execution_count": null, "outputs": [] @@ -322,7 +302,7 @@ " predict_fn=predict_fn,\n", " )\n", "\n", - "print(\"\\n✅ Evaluation complete\")\n", + "print(\"\\n\u2705 Evaluation complete\")\n", "print(f\" metrics: {getattr(results, 'metrics', 'N/A')}\")" ], "execution_count": null, @@ -332,14 +312,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Trace-derived eval — recent production calls\n", + "## Trace-derived eval \u2014 recent production calls\n", "\n", "The block above runs curated complaint scenarios. 
That's good for catching\n", "regressions on **known-good cases**, but it doesn't tell us how the agent\n", "is doing on **real production traffic**.\n", "\n", "This section harvests the most recent successful traces from the prod\n", - "experiment (`/Shared/{CATALOG}_complaint_agent_prod` — populated whenever\n", + "experiment (`/Shared/{CATALOG}_complaint_agent_prod` \u2014 populated whenever\n", "the streaming pipeline triggers a complaint), turns them back into eval\n", "inputs, and re-runs the agent against them with global quality guidelines.\n", "\n", @@ -347,7 +327,7 @@ "for the Operational Supervisor, so all three agents share the same flywheel:\n", "\n", "```\n", - "Production traffic → MLflow traces → resampled eval dataset → quality scoring\n", + "Production traffic \u2192 MLflow traces \u2192 resampled eval dataset \u2192 quality scoring\n", "```\n", "\n", "Skipped gracefully if the prod experiment doesn't exist yet or has no\n", @@ -373,13 +353,13 @@ "\n", "if _exp is None:\n", " print(\n", - " f\"Skipping trace eval — experiment {prod_experiment_name} not found.\\n\"\n", - " \" Send some complaints through the deployed endpoint first.\"\n", + " f\"Skipping trace eval \u2014 experiment {prod_experiment_name} not found.\\n\"\n", + " \" Send some complaints through the deployed app first.\"\n", " )\n", "else:\n", " print(f\"Found prod traces experiment: {prod_experiment_name} ({_exp.experiment_id})\")\n", "\n", - " # `return_type=\"list\"` is required — `mlflow.search_traces()` returns a\n", + " # `return_type=\"list\"` is required \u2014 `mlflow.search_traces()` returns a\n", " # pandas DataFrame by default in MLflow 3, and iterating it yields column\n", " # names (strings). 
The iteration below expects `Trace` objects with\n", " # `.data.spans` / `.info.request_id`, so we explicitly request the list form.\n", @@ -412,7 +392,7 @@ " \"complaint\": complaint,\n", " })\n", " except Exception as _te:\n", - " print(f\" ⚠️ skipped trace {getattr(trace.info, 'request_id', '?')}: {_te}\")\n", + " print(f\" \u26a0\ufe0f skipped trace {getattr(trace.info, 'request_id', '?')}: {_te}\")\n", " continue\n", "\n", " traces_df = pd.DataFrame(records)\n", @@ -469,10 +449,10 @@ " ],\n", " predict_fn=predict_fn,\n", " )\n", - " print(\"\\n✅ Trace-derived evaluation complete\")\n", + " print(\"\\n\u2705 Trace-derived evaluation complete\")\n", " print(f\" metrics: {getattr(trace_results, 'metrics', 'N/A')}\")\n", "else:\n", - " print(\"Skipping trace evaluation — no trace data available yet.\")" + " print(\"Skipping trace evaluation \u2014 no trace data available yet.\")" ], "execution_count": null, "outputs": [] diff --git a/stages/operational_app.ipynb b/stages/operational_app.ipynb index ebe6a77..1ae9f5e 100644 --- a/stages/operational_app.ipynb +++ b/stages/operational_app.ipynb @@ -53,6 +53,7 @@ "import sys, os\n", "sys.path.append('../utils')\n", "from uc_state import add\n", + "from agent_app_client import refund_agent_app_name, complaint_agent_app_name\n", "\n", "from databricks.sdk import WorkspaceClient\n", "from databricks.sdk.service.apps import (\n", @@ -64,7 +65,7 @@ "\n", "w = WorkspaceClient()\n", "\n", - "# Find the ops warehouse. DABs owns this resource — it's defined under\n", + "# Find the ops warehouse. DABs owns this resource \u2014 it's defined under\n", "# `resources.sql_warehouses.caspers_ops_warehouse` in databricks.yml and\n", "# created by `bundle deploy -t all`. 
This stage is now find-only: if the\n", "# warehouse is missing, deploy hasn't been run (or the resource was removed).\n", @@ -98,10 +99,10 @@ "APP_NAME = _re.sub(r\"-+\", \"-\", _re.sub(r\"[^a-z0-9-]\", \"-\", f\"ops-dashboard-{CATALOG}\".lower())).strip(\"-\")[:30]\n", "print(f\"App name: {APP_NAME}\")\n", "\n", - "# Lakebase Autoscale does not support AppResourceDatabase — the app connects\n", + "# Lakebase Autoscale does not support AppResourceDatabase \u2014 the app connects\n", "# directly using w.postgres.generate_database_credential() with the endpoint path.\n", "APP_DESCRIPTION = (\n", - " \"Casper's Ops Intelligence — chat with the Multi-Agent Supervisor \"\n", + " \"Casper's Ops Intelligence \u2014 chat with the Multi-Agent Supervisor \"\n", " \"across revenue, ops, food safety, legal & regulatory; review live \"\n", " \"complaint decisions and trigger refunds inline.\"\n", ")\n", @@ -153,7 +154,7 @@ "add(CATALOG, \"apps\", app_status)\n", "print(f\"\\u2705 App {APP_NAME} ready\")\n", "\n", - "# Resolve SP ID immediately — used in all subsequent permission cells.\n", + "# Resolve SP ID immediately \u2014 used in all subsequent permission cells.\n", "# service_principal_client_id is the documented OAuth UUID field.\n", "app_sp_id = (\n", " getattr(app_status, 'service_principal_client_id', None)\n", @@ -206,7 +207,7 @@ "# Pre-create the complaints schema and complaint_responses table so the SELECT\n", "# grant below always lands. Without this, Operational_App races with\n", "# Complaint_Agent_Stream (they have no depends_on relationship in databricks.yml,\n", - "# and the Complaint_Agent_Stream stage only kicks off a cron job and exits — the\n", + "# and the Complaint_Agent_Stream stage only kicks off a cron job and exits \u2014 the\n", "# job is what actually creates the schema/table on first run). 
When grants run\n", "# before the table exists, the per-grant try/except swallows the failure and the\n", "# Ops App's /api/complaint-decisions endpoint then fails with INSUFFICIENT_PERMISSIONS\n", @@ -234,7 +235,7 @@ " # `ai` for the app SP, the agent surfaces a PermissionError(\"EXECUTE on\n", " # Routine '.ai.get_order_details'\") on the very first tool call.\n", " # The `account users` grant the refunder_agent stage applies does NOT\n", - " # cover app SPs — workspace-local app SPs are not members of\n", + " # cover app SPs \u2014 workspace-local app SPs are not members of\n", " # `account users` (an account-level group), which is the regression that\n", " # broke this stage on 2026-06-08.\n", " (f\"{CATALOG}.ai\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", @@ -247,7 +248,7 @@ " (f\"{CATALOG}.lakeflow.gold_location_sales_hourly\", \"TABLE\", catalog_svc.Privilege.SELECT),\n", " # Materialized view backing the ops dashboard cancel-rate widget\n", " # (see /api/revenue in apps/caspers-ops-dashboard). 
Without this\n", - " # explicit grant the MV defaults to owner-only access — schema-level\n", + " # explicit grant the MV defaults to owner-only access \u2014 schema-level\n", " # USE_SCHEMA does NOT cascade SELECT to MVs in UC the way it does to\n", " # regular tables, so /api/revenue 503s on the very first page load.\n", " (f\"{CATALOG}.lakeflow.gold_location_order_status_daily\", \"TABLE\", catalog_svc.Privilege.SELECT),\n", @@ -262,7 +263,7 @@ " (f\"{CATALOG}.food_safety\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", " (f\"{CATALOG}.food_safety.inspections\", \"TABLE\", catalog_svc.Privilege.SELECT),\n", " (f\"{CATALOG}.food_safety.violations\", \"TABLE\", catalog_svc.Privilege.SELECT),\n", - " # Document schemas — needed so the app can list PDFs from UC volumes\n", + " # Document schemas \u2014 needed so the app can list PDFs from UC volumes\n", " (f\"{CATALOG}.legal_complaints\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", " (f\"{CATALOG}.regulatory\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", " (f\"{CATALOG}.audits\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", @@ -273,7 +274,7 @@ " (f\"{CATALOG}.consultancy\", \"SCHEMA\", catalog_svc.Privilege.READ_VOLUME),\n", " # food_safety has /Volumes/{CATALOG}/food_safety/reports referenced from the app\n", " (f\"{CATALOG}.food_safety\", \"SCHEMA\", catalog_svc.Privilege.READ_VOLUME),\n", - " # complaints schema — queried by /api/complaint-decisions in the app\n", + " # complaints schema \u2014 queried by /api/complaint-decisions in the app\n", " (f\"{CATALOG}.complaints\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", " (f\"{CATALOG}.complaints.complaint_responses\", \"TABLE\", catalog_svc.Privilege.SELECT),\n", "]:\n", @@ -292,7 +293,7 @@ " except Exception as e:\n", " print(f\"\\u26a0\\ufe0f Could not grant {privilege} on {full_name}: {e}\")\n", "\n", - "# ── Grant CAN_USE on the Genie warehouse to the app SP ──────────────────────\n", + "# \u2500\u2500 Grant CAN_USE on the Genie 
warehouse to the app SP \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "# Genie sub-agents called via the supervisor execute SQL on the warehouse\n", "# created by stages/genie_spaces.ipynb (`{CATALOG}-genie-warehouse`). Without\n", "# CAN_USE the warehouse refuses the query and the supervisor LLM surfaces it\n", @@ -301,7 +302,7 @@ " genie_wh_name = f\"{CATALOG}-genie-warehouse\"\n", " genie_wh = next((wh for wh in w.warehouses.list() if wh.name == genie_wh_name), None)\n", " if genie_wh is None:\n", - " print(f\"\\u26a0\\ufe0f Genie warehouse '{genie_wh_name}' not found — skipping CAN_USE grant\")\n", + " print(f\"\\u26a0\\ufe0f Genie warehouse '{genie_wh_name}' not found \u2014 skipping CAN_USE grant\")\n", " else:\n", " w.api_client.do(\n", " \"PATCH\",\n", @@ -329,8 +330,8 @@ "\n", "def _grant_can_query(endpoint_id: str, endpoint_name: str, sp_id: str, retries: int = 3) -> bool:\n", " \"\"\"Grant CAN_QUERY on a serving endpoint.\n", - " endpoint_id — the UUID returned by the list API (required by update_permissions)\n", - " endpoint_name — human-readable name used only for logging\n", + " endpoint_id \u2014 the UUID returned by the list API (required by update_permissions)\n", + " endpoint_name \u2014 human-readable name used only for logging\n", " \"\"\"\n", " for attempt in range(retries):\n", " try:\n", @@ -396,14 +397,14 @@ " print(f\"\\u274c FAILED to grant CAN_RUN on Genie space {space_id}: {e}\")\n", " return \"failed\"\n", "\n", - "# ── Resolve endpoints + Genie spaces created BY THIS CATALOG from uc_state ────\n", + "# \u2500\u2500 Resolve endpoints + Genie spaces created BY THIS CATALOG from uc_state \u2500\u2500\u2500\u2500\n", "# P1-16: previously this cell scanned every mas-*/ka-* endpoint and every Genie\n", "# space in the workspace and granted to the app SP. Two pain points:\n", - "# 1. 
Cross-tenant blast radius — a colleague's `mas-xxxxxxxx-endpoint` in the\n", + "# 1. Cross-tenant blast radius \u2014 a colleague's `mas-xxxxxxxx-endpoint` in the\n", "# same workspace would silently get our app SP granted CAN_QUERY.\n", "# 2. The grant list grew unboundedly with every catalog deployed, so\n", "# runtime got slower and quota errors started surfacing.\n", - "# uc_state stores the resources the current catalog's stages created — scope\n", + "# uc_state stores the resources the current catalog's stages created \u2014 scope\n", "# the grants to that set only.\n", "import json as _json\n", "\n", @@ -416,12 +417,14 @@ " \"\"\").collect()\n", " return [_json.loads(r.resource_data) for r in rows]\n", " except Exception as e:\n", - " print(f\"⚠️ uc_state lookup failed for {resource_type}: {e}\")\n", + " print(f\"\u26a0\ufe0f uc_state lookup failed for {resource_type}: {e}\")\n", " return []\n", "\n", - "# Build the set of endpoint names this catalog owns.\n", + "# Build the set of serving endpoint names this catalog owns.\n", "# - multi_agent_supervisors: row has `endpoint_name`\n", - "# - knowledge_assistants: row has `tile_id` → endpoint name is `ka-{tile_id[:8]}-endpoint`\n", + "# - knowledge_assistants: row has `tile_id` -> endpoint name is `ka-{tile_id[:8]}-endpoint`\n", + "# Refund and complaint custom agents are Databricks Apps now, so they are\n", + "# granted below via the Apps permissions API instead of CAN_QUERY here.\n", "owned_endpoint_names = set()\n", "for row in _uc_state_rows(\"multi_agent_supervisors\"):\n", " n = row.get(\"endpoint_name\")\n", @@ -431,9 +434,11 @@ " tile_id = row.get(\"tile_id\") or \"\"\n", " if tile_id:\n", " owned_endpoint_names.add(f\"ka-{tile_id[:8]}-endpoint\")\n", - "# Custom agents deployed via agents.deploy() — names are deterministic from CATALOG\n", - "for ep in [f\"{CATALOG}_refund_agent\", f\"{CATALOG}_complaint_agent\"]:\n", - " owned_endpoint_names.add(ep)\n", + "\n", + "owned_agent_app_names = {\n", + " 
refund_agent_app_name(CATALOG),\n", + " complaint_agent_app_name(CATALOG),\n", + "}\n", "\n", "# Build the set of Genie spaces this catalog owns (`space_id`).\n", "# uc_state may contain stale rows from prior deploys: the `genie_spaces` stage\n", @@ -462,9 +467,9 @@ " f\"the {len(owned_space_ids)} most-recent space(s) per title.\"\n", " )\n", "\n", - "print(f\"uc_state: {len(owned_endpoint_names)} owned endpoints, {len(owned_space_ids)} owned Genie spaces\")\n", + "print(f\"uc_state: {len(owned_endpoint_names)} owned endpoints, {len(owned_space_ids)} owned Genie spaces, {len(owned_agent_app_names)} custom agent apps\")\n", "\n", - "# ── Grant CAN_QUERY on owned mas-*/ka-* serving endpoints ─────────────────────\n", + "# \u2500\u2500 Grant CAN_QUERY on owned mas-*/ka-* serving endpoints \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "# IMPORTANT: update_permissions requires the endpoint UUID (ep.id), not the name.\n", "# We still call w.serving_endpoints.list() because the permissions API needs\n", "# the UUID, but we filter to only the names we own.\n", @@ -475,14 +480,14 @@ "# (catches drift between uc_state and reality, e.g. 
someone deleted an endpoint).\n", "missing = owned_endpoint_names - {ep.name for ep in target_eps}\n", "if missing:\n", - " print(f\"⚠️ uc_state lists endpoints not present in workspace (skipping): {sorted(missing)}\")\n", + " print(f\"\u26a0\ufe0f uc_state lists endpoints not present in workspace (skipping): {sorted(missing)}\")\n", "print(f\"Granting CAN_QUERY on {len(target_eps)} owned endpoints\")\n", "for ep in target_eps:\n", " ok = _grant_can_query(ep.id, ep.name, app_sp_id)\n", " if not ok:\n", " ep_failures.append(ep.name)\n", "\n", - "# ── Grant CAN_RUN on owned Genie spaces ───────────────────────────────────────\n", + "# \u2500\u2500 Grant CAN_RUN on owned Genie spaces \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "genie_failures = []\n", "genie_orphans = []\n", "print(f\"Granting CAN_RUN on {len(owned_space_ids)} owned Genie spaces\")\n", @@ -498,8 +503,8 @@ " f\"(deleted or owned by another identity): {genie_orphans}\"\n", " )\n", "\n", - "# ── Fail loudly if any grants didn't land ─────────────────────────────────────\n", - "# Only real failures count — orphans are treated as warnings (stale uc_state).\n", + "# \u2500\u2500 Fail loudly if any grants didn't land \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", + "# Only real failures count \u2014 orphans are treated as warnings (stale uc_state).\n", "if ep_failures or genie_failures:\n", " raise RuntimeError(\n", " f\"Permission grants failed!\\n\"\n", @@ -578,14 +583,236 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "##### Write app.yaml — assembles all resource IDs from uc_state before deploy" + "##### Write app.yaml \u2014 
assembles all resource IDs from uc_state before deploy" ] }, { "cell_type": "code", "metadata": {}, "source": [ - "import json as _json, re as _re, os as _os\n\ndef _latest_from_uc_state(resource_type, key):\n try:\n df = spark.sql(f\"\"\"\n SELECT resource_data FROM {CATALOG}._internal_state.resources\n WHERE resource_type = '{resource_type}'\n ORDER BY created_at DESC\n \"\"\")\n for row in df.collect():\n val = _json.loads(row.resource_data).get(key, \"\")\n if val:\n return val\n except Exception as e:\n print(f\"⚠️ uc_state lookup failed for {resource_type}/{key}: {e}\")\n return \"\"\n\ndef _ka_tile_id(ka_name):\n try:\n df = spark.sql(f\"\"\"\n SELECT resource_data FROM {CATALOG}._internal_state.resources\n WHERE resource_type = 'knowledge_assistants'\n ORDER BY created_at DESC\n \"\"\")\n for row in df.collect():\n info = _json.loads(row.resource_data)\n if info.get(\"name\") == ka_name:\n return info.get(\"tile_id\", \"\")\n except Exception as e:\n print(f\"⚠️ Could not read KA tile {ka_name}: {e}\")\n return \"\"\n\n# Supervisor\nsupervisor_endpoint = _latest_from_uc_state(\"multi_agent_supervisors\", \"endpoint_name\")\nsupervisor_tile_id = _latest_from_uc_state(\"multi_agent_supervisors\", \"tile_id\")\n\nsupervisor_mlflow_experiment_id = \"\"\nif supervisor_tile_id:\n try:\n _tile = w.api_client.do(\"GET\", f\"/api/2.0/tiles/{supervisor_tile_id}\")\n supervisor_mlflow_experiment_id = str(_tile.get(\"mlflow_experiment_id\") or \"\")\n except Exception as e:\n print(f\"⚠️ Could not fetch mlflow_experiment_id from tile: {e}\")\n\n# Genie spaces (titles must match stages/genie_spaces.ipynb)\n_revenue_title = f\"Revenue & Orders Intelligence ({CATALOG})\"\n_ops_title = f\"Operations Intelligence ({CATALOG})\"\n_menu_title = f\"Menu & Safety Intelligence ({CATALOG})\"\ngenie_revenue_id = genie_ops_id = genie_menu_id = \"\"\ntry:\n df = spark.sql(f\"\"\"\n SELECT resource_data FROM {CATALOG}._internal_state.resources\n WHERE resource_type = 'genie_spaces'\n 
ORDER BY created_at DESC\n \"\"\")\n for row in df.collect():\n info = _json.loads(row.resource_data)\n title = info.get(\"title\")\n if title == _revenue_title and not genie_revenue_id:\n genie_revenue_id = info.get(\"space_id\", \"\")\n elif title == _ops_title and not genie_ops_id:\n genie_ops_id = info.get(\"space_id\", \"\")\n elif title == _menu_title and not genie_menu_id:\n genie_menu_id = info.get(\"space_id\", \"\")\nexcept Exception as e:\n print(f\"⚠️ Could not read genie IDs: {e}\")\n\n# KA tile IDs (one per Knowledge Assistant created by stages/knowledge_agents.ipynb)\nka_inspection = _ka_tile_id(f\"{CATALOG}-inspection-knowledge\")\nka_menu = _ka_tile_id(f\"{CATALOG}-menu-knowledge\")\nka_legal = _ka_tile_id(f\"{CATALOG}-legal\")\nka_regulatory = _ka_tile_id(f\"{CATALOG}-regulatory\")\nka_audits = _ka_tile_id(f\"{CATALOG}-audits\")\nka_consultancy = _ka_tile_id(f\"{CATALOG}-consultancy\")\n\n# Operations dashboard embedded in the app's opt-in \"Ops Dashboard\" tab.\n# We publish TWO variants — light + dark — because Databricks embedded\n# dashboards always render their `light` theme slot (the `?theme=` URL\n# param is ignored). The dark variant bakes dark-palette tokens into\n# the `light` slot so the iframe renders with dark colors. The frontend\n# swaps the iframe src when the user toggles the theme. 
The Revenue\n# card on the always-visible main view stays on Chart.js because the\n# embed's other limitations (separate workspace login on first view,\n# fixed canvas width) were deal-breakers for an always-visible panel.\nops_dashboard_id = \"\" # light variant — default, used by app's light theme\nops_dashboard_id_dark = \"\" # dark variant — used when app theme is dark\nops_dashboard_page = \"\"\n_dash_titles = {\n \"light\": f\"Casper's Kitchen - Operations ({CATALOG})\",\n \"dark\": f\"Casper's Kitchen - Operations - Dark ({CATALOG})\",\n}\ntry:\n _by_title = {(_d.display_name or \"\"): _d for _d in w.lakeview.list()}\n for _variant, _title in _dash_titles.items():\n _d = _by_title.get(_title)\n if not _d:\n print(f\"⚠️ Dashboard '{_title}' not found via lakeview.list()\")\n continue\n _id = _d.dashboard_id or \"\"\n if _variant == \"light\":\n ops_dashboard_id = _id\n else:\n ops_dashboard_id_dark = _id\n try:\n # embed_credentials=True so the iframe runs as the dashboard\n # owner — query execution doesn't require the viewer to\n # authenticate to cloud.databricks.com separately.\n w.lakeview.publish(dashboard_id=_id, embed_credentials=True)\n print(f\"✅ Ops dashboard ({_variant}) (re-)published: {_id}\")\n except Exception as _pe:\n # ALREADY_PUBLISHED and similar racy states are fine.\n print(f\"⚠️ Could not (re-)publish {_variant} dashboard {_id}: {_pe}\")\nexcept Exception as _e:\n print(f\"⚠️ Could not look up ops dashboards: {_e}\")\n\n# Lakebase endpoint path — deterministic from CATALOG, points at the\n# shared Autoscaling project (`-caspers`) created by\n# stages/lakebase_project.ipynb. 
Reuses OPS_PROJECT_ID computed at the top\n# of this notebook.\nlakebase_endpoint_path = OPS_ENDPOINT_PATH\n\n# Custom agent endpoints — names follow the DABs widget default pattern\nrefund_agent_endpoint = f\"{CATALOG}_refund_agent\"\ncomplaint_agent_endpoint = f\"{CATALOG}_complaint_agent\"\n\n# Refund Manager app URL — look up by name, empty string if not yet deployed\nrefund_manager_app_name = _re.sub(r'-+', '-', _re.sub(r'[^a-z0-9-]', '-', f'refundmanager-{CATALOG}'.lower())).strip('-')[:30]\ntry:\n _rm_app = w.apps.get(refund_manager_app_name)\n refund_manager_app_url = getattr(_rm_app, 'url', '') or ''\n print(f'Refund Manager app: {refund_manager_app_url}')\nexcept Exception:\n refund_manager_app_url = ''\n print(f'Refund Manager app not found ({refund_manager_app_name}) — REFUND_MANAGER_APP_URL will be empty')\n\napp_yaml_path = _os.path.abspath(\"../apps/caspers-ops-dashboard/app.yaml\")\napp_yaml_contents = f\"\"\"command:\n - uvicorn\n - app.main:app\nenv:\n - name: LAKEBASE_ENDPOINT_PATH\n value: '{lakebase_endpoint_path}'\n - name: LAKEBASE_DATABASE_NAME\n value: '{OPS_DATABASE_NAME}'\n - name: DATABRICKS_CATALOG\n value: '{CATALOG}'\n - name: SUPERVISOR_ENDPOINT\n value: '{supervisor_endpoint}'\n - name: SUPERVISOR_TILE_ID\n value: '{supervisor_tile_id}'\n - name: SUPERVISOR_MLFLOW_EXPERIMENT_ID\n value: '{supervisor_mlflow_experiment_id}'\n - name: GENIE_ID_REVENUE\n value: '{genie_revenue_id}'\n - name: GENIE_ID_OPS\n value: '{genie_ops_id}'\n - name: GENIE_ID_MENU\n value: '{genie_menu_id}'\n - name: KA_ID_INSPECTION\n value: '{ka_inspection}'\n - name: KA_ID_MENU\n value: '{ka_menu}'\n - name: KA_ID_LEGAL\n value: '{ka_legal}'\n - name: KA_ID_REGULATORY\n value: '{ka_regulatory}'\n - name: KA_ID_AUDITS\n value: '{ka_audits}'\n - name: KA_ID_CONSULTANCY\n value: '{ka_consultancy}'\n - name: DATABRICKS_WAREHOUSE_ID\n value: '{warehouse.id}'\n - name: OPS_DASHBOARD_ID\n value: '{ops_dashboard_id}'\n - name: OPS_DASHBOARD_ID_DARK\n value: 
'{ops_dashboard_id_dark}'\n - name: OPS_DASHBOARD_PAGE\n value: '{ops_dashboard_page}'\n - name: REFUND_AGENT_ENDPOINT\n value: '{refund_agent_endpoint}'\n - name: COMPLAINT_AGENT_ENDPOINT\n value: '{complaint_agent_endpoint}'\n - name: REFUND_MANAGER_APP_URL\n value: '{refund_manager_app_url}'\n\"\"\"\nwith open(app_yaml_path, \"w\") as _f:\n _f.write(app_yaml_contents)\n\nprint(f\"✅ Wrote app.yaml to {app_yaml_path}\")\nprint(f\" Supervisor endpoint: {supervisor_endpoint}\")\nprint(f\" Supervisor tile: {supervisor_tile_id}\")\nprint(f\" MLflow experiment: {supervisor_mlflow_experiment_id}\")\nprint(f\" Lakebase endpoint: {lakebase_endpoint_path}\")\nprint(f\" Revenue Genie: {genie_revenue_id}\")\nprint(f\" Ops Genie: {genie_ops_id}\")\nprint(f\" Menu Genie: {genie_menu_id}\")\nprint(f\" KA Inspection: {ka_inspection}\")\nprint(f\" KA Menu: {ka_menu}\")\nprint(f\" KA Legal: {ka_legal}\")\nprint(f\" KA Regulatory: {ka_regulatory}\")\nprint(f\" KA Audits: {ka_audits}\")\nprint(f\" KA Consultancy: {ka_consultancy}\")\nprint(f\" Warehouse: {warehouse.id}\")\nprint(f\" Ops Dashboard (light):{ops_dashboard_id}\")\nprint(f\" Ops Dashboard (dark): {ops_dashboard_id_dark}\")\nprint(f\" Ops Dashboard page: {ops_dashboard_page or '(default landing page)'}\")" + "import json as _json, re as _re, os as _os\n", + "from agent_app_client import refund_agent_app_name, complaint_agent_app_name\n", + "\n", + "def _latest_from_uc_state(resource_type, key):\n", + " try:\n", + " df = spark.sql(f\"\"\"\n", + " SELECT resource_data FROM {CATALOG}._internal_state.resources\n", + " WHERE resource_type = '{resource_type}'\n", + " ORDER BY created_at DESC\n", + " \"\"\")\n", + " for row in df.collect():\n", + " val = _json.loads(row.resource_data).get(key, \"\")\n", + " if val:\n", + " return val\n", + " except Exception as e:\n", + " print(f\"\u26a0\ufe0f uc_state lookup failed for {resource_type}/{key}: {e}\")\n", + " return \"\"\n", + "\n", + "def _ka_tile_id(ka_name):\n", + " try:\n", + 
" df = spark.sql(f\"\"\"\n", + " SELECT resource_data FROM {CATALOG}._internal_state.resources\n", + " WHERE resource_type = 'knowledge_assistants'\n", + " ORDER BY created_at DESC\n", + " \"\"\")\n", + " for row in df.collect():\n", + " info = _json.loads(row.resource_data)\n", + " if info.get(\"name\") == ka_name:\n", + " return info.get(\"tile_id\", \"\")\n", + " except Exception as e:\n", + " print(f\"\u26a0\ufe0f Could not read KA tile {ka_name}: {e}\")\n", + " return \"\"\n", + "\n", + "# Supervisor\n", + "supervisor_endpoint = _latest_from_uc_state(\"multi_agent_supervisors\", \"endpoint_name\")\n", + "supervisor_tile_id = _latest_from_uc_state(\"multi_agent_supervisors\", \"tile_id\")\n", + "\n", + "supervisor_mlflow_experiment_id = \"\"\n", + "if supervisor_tile_id:\n", + " try:\n", + " _tile = w.api_client.do(\"GET\", f\"/api/2.0/tiles/{supervisor_tile_id}\")\n", + " supervisor_mlflow_experiment_id = str(_tile.get(\"mlflow_experiment_id\") or \"\")\n", + " except Exception as e:\n", + " print(f\"\u26a0\ufe0f Could not fetch mlflow_experiment_id from tile: {e}\")\n", + "\n", + "# Genie spaces (titles must match stages/genie_spaces.ipynb)\n", + "_revenue_title = f\"Revenue & Orders Intelligence ({CATALOG})\"\n", + "_ops_title = f\"Operations Intelligence ({CATALOG})\"\n", + "_menu_title = f\"Menu & Safety Intelligence ({CATALOG})\"\n", + "genie_revenue_id = genie_ops_id = genie_menu_id = \"\"\n", + "try:\n", + " df = spark.sql(f\"\"\"\n", + " SELECT resource_data FROM {CATALOG}._internal_state.resources\n", + " WHERE resource_type = 'genie_spaces'\n", + " ORDER BY created_at DESC\n", + " \"\"\")\n", + " for row in df.collect():\n", + " info = _json.loads(row.resource_data)\n", + " title = info.get(\"title\")\n", + " if title == _revenue_title and not genie_revenue_id:\n", + " genie_revenue_id = info.get(\"space_id\", \"\")\n", + " elif title == _ops_title and not genie_ops_id:\n", + " genie_ops_id = info.get(\"space_id\", \"\")\n", + " elif title == 
_menu_title and not genie_menu_id:\n", + " genie_menu_id = info.get(\"space_id\", \"\")\n", + "except Exception as e:\n", + " print(f\"\u26a0\ufe0f Could not read genie IDs: {e}\")\n", + "\n", + "# KA tile IDs (one per Knowledge Assistant created by stages/knowledge_agents.ipynb)\n", + "ka_inspection = _ka_tile_id(f\"{CATALOG}-inspection-knowledge\")\n", + "ka_menu = _ka_tile_id(f\"{CATALOG}-menu-knowledge\")\n", + "ka_legal = _ka_tile_id(f\"{CATALOG}-legal\")\n", + "ka_regulatory = _ka_tile_id(f\"{CATALOG}-regulatory\")\n", + "ka_audits = _ka_tile_id(f\"{CATALOG}-audits\")\n", + "ka_consultancy = _ka_tile_id(f\"{CATALOG}-consultancy\")\n", + "\n", + "# Operations dashboard embedded in the app's opt-in \"Ops Dashboard\" tab.\n", + "# We publish TWO variants \u2014 light + dark \u2014 because Databricks embedded\n", + "# dashboards always render their `light` theme slot (the `?theme=` URL\n", + "# param is ignored). The dark variant bakes dark-palette tokens into\n", + "# the `light` slot so the iframe renders with dark colors. The frontend\n", + "# swaps the iframe src when the user toggles the theme. 
The Revenue\n", + "# card on the always-visible main view stays on Chart.js because the\n", + "# embed's other limitations (separate workspace login on first view,\n", + "# fixed canvas width) were deal-breakers for an always-visible panel.\n", + "ops_dashboard_id = \"\" # light variant \u2014 default, used by app's light theme\n", + "ops_dashboard_id_dark = \"\" # dark variant \u2014 used when app theme is dark\n", + "ops_dashboard_page = \"\"\n", + "_dash_titles = {\n", + " \"light\": f\"Casper's Kitchen - Operations ({CATALOG})\",\n", + " \"dark\": f\"Casper's Kitchen - Operations - Dark ({CATALOG})\",\n", + "}\n", + "try:\n", + " _by_title = {(_d.display_name or \"\"): _d for _d in w.lakeview.list()}\n", + " for _variant, _title in _dash_titles.items():\n", + " _d = _by_title.get(_title)\n", + " if not _d:\n", + " print(f\"\u26a0\ufe0f Dashboard '{_title}' not found via lakeview.list()\")\n", + " continue\n", + " _id = _d.dashboard_id or \"\"\n", + " if _variant == \"light\":\n", + " ops_dashboard_id = _id\n", + " else:\n", + " ops_dashboard_id_dark = _id\n", + " try:\n", + " # embed_credentials=True so the iframe runs as the dashboard\n", + " # owner \u2014 query execution doesn't require the viewer to\n", + " # authenticate to cloud.databricks.com separately.\n", + " w.lakeview.publish(dashboard_id=_id, embed_credentials=True)\n", + " print(f\"\u2705 Ops dashboard ({_variant}) (re-)published: {_id}\")\n", + " except Exception as _pe:\n", + " # ALREADY_PUBLISHED and similar racy states are fine.\n", + " print(f\"\u26a0\ufe0f Could not (re-)publish {_variant} dashboard {_id}: {_pe}\")\n", + "except Exception as _e:\n", + " print(f\"\u26a0\ufe0f Could not look up ops dashboards: {_e}\")\n", + "\n", + "# Lakebase endpoint path \u2014 deterministic from CATALOG, points at the\n", + "# shared Autoscaling project (`-caspers`) created by\n", + "# stages/lakebase_project.ipynb. 
Reuses OPS_PROJECT_ID computed at the top\n", + "# of this notebook.\n", + "lakebase_endpoint_path = OPS_ENDPOINT_PATH\n", + "\n", + "# Custom agent Apps \u2014 names are deterministic and catalog-scoped.\n", + "refund_agent_app = refund_agent_app_name(CATALOG)\n", + "complaint_agent_app = complaint_agent_app_name(CATALOG)\n", + "\n", + "def _app_url_or_empty(app_name):\n", + " try:\n", + " return getattr(w.apps.get(app_name), \"url\", \"\") or \"\"\n", + " except Exception as e:\n", + " print(f\"\u26a0\ufe0f App {app_name} not found \u2014 URL will be empty: {e}\")\n", + " return \"\"\n", + "\n", + "refund_agent_app_url = _app_url_or_empty(refund_agent_app)\n", + "complaint_agent_app_url = _app_url_or_empty(complaint_agent_app)\n", + "\n", + "# Refund Manager app URL \u2014 look up by name, empty string if not yet deployed\n", + "refund_manager_app_name = _re.sub(r'-+', '-', _re.sub(r'[^a-z0-9-]', '-', f'refundmanager-{CATALOG}'.lower())).strip('-')[:30]\n", + "try:\n", + " _rm_app = w.apps.get(refund_manager_app_name)\n", + " refund_manager_app_url = getattr(_rm_app, 'url', '') or ''\n", + " print(f'Refund Manager app: {refund_manager_app_url}')\n", + "except Exception:\n", + " refund_manager_app_url = ''\n", + " print(f'Refund Manager app not found ({refund_manager_app_name}) \u2014 REFUND_MANAGER_APP_URL will be empty')\n", + "\n", + "app_yaml_path = _os.path.abspath(\"../apps/caspers-ops-dashboard/app.yaml\")\n", + "app_yaml_contents = f\"\"\"command:\n", + " - uvicorn\n", + " - app.main:app\n", + "env:\n", + " - name: LAKEBASE_ENDPOINT_PATH\n", + " value: '{lakebase_endpoint_path}'\n", + " - name: LAKEBASE_DATABASE_NAME\n", + " value: '{OPS_DATABASE_NAME}'\n", + " - name: DATABRICKS_CATALOG\n", + " value: '{CATALOG}'\n", + " - name: SUPERVISOR_ENDPOINT\n", + " value: '{supervisor_endpoint}'\n", + " - name: SUPERVISOR_TILE_ID\n", + " value: '{supervisor_tile_id}'\n", + " - name: SUPERVISOR_MLFLOW_EXPERIMENT_ID\n", + " value: 
'{supervisor_mlflow_experiment_id}'\n", + " - name: GENIE_ID_REVENUE\n", + " value: '{genie_revenue_id}'\n", + " - name: GENIE_ID_OPS\n", + " value: '{genie_ops_id}'\n", + " - name: GENIE_ID_MENU\n", + " value: '{genie_menu_id}'\n", + " - name: KA_ID_INSPECTION\n", + " value: '{ka_inspection}'\n", + " - name: KA_ID_MENU\n", + " value: '{ka_menu}'\n", + " - name: KA_ID_LEGAL\n", + " value: '{ka_legal}'\n", + " - name: KA_ID_REGULATORY\n", + " value: '{ka_regulatory}'\n", + " - name: KA_ID_AUDITS\n", + " value: '{ka_audits}'\n", + " - name: KA_ID_CONSULTANCY\n", + " value: '{ka_consultancy}'\n", + " - name: DATABRICKS_WAREHOUSE_ID\n", + " value: '{warehouse.id}'\n", + " - name: OPS_DASHBOARD_ID\n", + " value: '{ops_dashboard_id}'\n", + " - name: OPS_DASHBOARD_ID_DARK\n", + " value: '{ops_dashboard_id_dark}'\n", + " - name: OPS_DASHBOARD_PAGE\n", + " value: '{ops_dashboard_page}'\n", + " - name: REFUND_AGENT_APP_NAME\n", + " value: '{refund_agent_app}'\n", + " - name: REFUND_AGENT_APP_URL\n", + " value: '{refund_agent_app_url}'\n", + " - name: COMPLAINT_AGENT_APP_NAME\n", + " value: '{complaint_agent_app}'\n", + " - name: COMPLAINT_AGENT_APP_URL\n", + " value: '{complaint_agent_app_url}'\n", + " - name: REFUND_MANAGER_APP_URL\n", + " value: '{refund_manager_app_url}'\n", + "\"\"\"\n", + "with open(app_yaml_path, \"w\") as _f:\n", + " _f.write(app_yaml_contents)\n", + "\n", + "print(f\"\u2705 Wrote app.yaml to {app_yaml_path}\")\n", + "print(f\" Supervisor endpoint: {supervisor_endpoint}\")\n", + "print(f\" Supervisor tile: {supervisor_tile_id}\")\n", + "print(f\" MLflow experiment: {supervisor_mlflow_experiment_id}\")\n", + "print(f\" Lakebase endpoint: {lakebase_endpoint_path}\")\n", + "print(f\" Revenue Genie: {genie_revenue_id}\")\n", + "print(f\" Ops Genie: {genie_ops_id}\")\n", + "print(f\" Menu Genie: {genie_menu_id}\")\n", + "print(f\" KA Inspection: {ka_inspection}\")\n", + "print(f\" KA Menu: {ka_menu}\")\n", + "print(f\" KA Legal: {ka_legal}\")\n", + 
"print(f\" KA Regulatory: {ka_regulatory}\")\n", + "print(f\" KA Audits: {ka_audits}\")\n", + "print(f\" KA Consultancy: {ka_consultancy}\")\n", + "print(f\" Warehouse: {warehouse.id}\")\n", + "print(f\" Ops Dashboard (light):{ops_dashboard_id}\")\n", + "print(f\" Ops Dashboard (dark): {ops_dashboard_id_dark}\")\n", + "print(f\" Ops Dashboard page: {ops_dashboard_page or '(default landing page)'}\")\n", + "print(f\" Refund Agent app: {refund_agent_app} ({refund_agent_app_url or 'URL pending'})\")\n", + "print(f\" Complaint Agent app: {complaint_agent_app} ({complaint_agent_app_url or 'URL pending'})\")" ], "execution_count": null, "outputs": [] @@ -624,14 +851,14 @@ "print(f\" URL: {app_status.url if hasattr(app_status, 'url') else 'Check Databricks Apps UI'}\")\n", "\n", "# Discover Domain tagging. The Operational Dashboard is the operator's\n", - "# single pane of glass over the orchestrated agents — pure operations\n", + "# single pane of glass over the orchestrated agents \u2014 pure operations\n", "# domain. Genie spaces, KAs, dashboards used inside the app each carry\n", "# their own tags so they ALSO surface in their respective domains. 
See\n", - "# SETUP.ipynb §4 for the manual one-time Domain creation step.\n", + "# SETUP.ipynb \u00a74 for the manual one-time Domain creation step.\n", "import sys as _sys, os as _os\n", "_sys.path.insert(0, _os.path.abspath(\"..\")) # stages/ -> repo root\n", "from utils.domain_tags import ensure_domain_tag_policies, tag_workspace_entity\n", - "print(\"\\n— Tagging operational dashboard app for Discover Domains —\")\n", + "print(\"\\n\u2014 Tagging operational dashboard app for Discover Domains \u2014\")\n", "ensure_domain_tag_policies(w, verbose=False)\n", "tag_workspace_entity(w, \"apps\", app_status.name, [\"operations\"],\n", " label=f\"app {app_status.name!r}\")\n", diff --git a/stages/refund_evaluation.ipynb b/stages/refund_evaluation.ipynb index debb72d..589927e 100644 --- a/stages/refund_evaluation.ipynb +++ b/stages/refund_evaluation.ipynb @@ -4,29 +4,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Refund Agent — Evaluation\n", + "### Refund Agent \u2014 Evaluation\n", "\n", - "Standalone evaluation task for the deployed Refund Agent. Runs after\n", - "`Refund_Recommender_Agent` in the pipeline so the endpoint is up before\n", - "we hit it.\n", + "Standalone evaluation task for the deployed Refund Agent App. Runs after\n", + "`Refund_Recommender_Agent` in the pipeline so the app is ready before we hit it.\n", "\n", - "- **Gated by `SKIP_EVAL`** (default `\"true\"`) — flip to `\"false\"` to actually\n", + "- **Gated by `SKIP_EVAL`** (default `\"true\"`) \u2014 flip to `\"false\"` to actually\n", " run the evaluation. 
The default skip is a deliberate rate-limit safeguard:\n", " evaluation is the burst-iest LLM consumer in the bundle and can hit the\n", - " shared FMAPI endpoint's QPS cap when refund + complaint eval run\n", - " back-to-back.\n", - "- Calls the deployed endpoint via `mlflow.deployments.get_deploy_client`\n", - " rather than importing the agent module — so this notebook is fully\n", - " self-contained and tests the actual production endpoint shape.\n", - "- Eval results land in `/Shared/{CATALOG}_refund_agent_dev`." + " shared judge endpoint's QPS cap when refund + complaint eval run back-to-back.\n", + "- Calls the deployed Databricks App via its MLflow AgentServer `/responses`\n", + " contract rather than importing the agent module.\n", + "- Eval results land in `/Shared/{CATALOG}_refund_agent_dev`.\n" ] }, { "cell_type": "code", "metadata": {}, "source": [ - "%pip install -U -qqqq mlflow-skinny[databricks] databricks-sdk\n", - "dbutils.library.restartPython()" + "%pip install -U -qqqq mlflow-skinny[databricks] databricks-sdk requests\n", + "dbutils.library.restartPython()\n" ], "execution_count": null, "outputs": [] @@ -36,7 +33,13 @@ "metadata": {}, "source": [ "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "ENDPOINT_NAME = dbutils.widgets.get(\"REFUND_AGENT_ENDPOINT_NAME\")\n", + "\n", + "import os\n", + "import sys\n", + "sys.path.append(os.path.abspath(\"../utils\"))\n", + "from agent_app_client import refund_agent_app_name\n", + "\n", + "APP_NAME = refund_agent_app_name(CATALOG)\n", "\n", "try:\n", " SKIP_EVAL = dbutils.widgets.get(\"SKIP_EVAL\").strip().lower() == \"true\"\n", @@ -49,10 +52,10 @@ "mlflow.set_experiment(DEV_EXPERIMENT)\n", "\n", "print(f\"Catalog: {CATALOG}\")\n", - "print(f\"Endpoint: {ENDPOINT_NAME}\")\n", + "print(f\"App: {APP_NAME}\")\n", "print(f\"SKIP_EVAL: {SKIP_EVAL}\")\n", "print(f\"Dev experiment: {DEV_EXPERIMENT}\")\n", - "print(f\"MLflow version: {mlflow.__version__}\")" + "print(f\"MLflow version: {mlflow.__version__}\")\n" 
], "execution_count": null, "outputs": [] @@ -63,7 +66,7 @@ "source": [ "if SKIP_EVAL:\n", " print(\n", - " f\"⏭ SKIP_EVAL=true — skipping mlflow.genai.evaluate to avoid the ~20 \"\n", + " f\"\u23ed SKIP_EVAL=true \u2014 skipping mlflow.genai.evaluate to avoid the ~20 \"\n", " f\"LM-call eval burst. Pass --params \\\"SKIP_EVAL=false\\\" to actually run \"\n", " f\"the evaluation (e.g. before a demo or after a prompt change).\"\n", " )\n", @@ -84,25 +87,34 @@ "POLL_INTERVAL_S = 15\n", "MAX_POLLS = 40\n", "\n", - "print(f\"Polling endpoint readiness ({MAX_POLLS * POLL_INTERVAL_S // 60} min max)…\")\n", + "print(f\"Polling app readiness ({MAX_POLLS * POLL_INTERVAL_S // 60} min max)...\")\n", + "\n", + "\n", + "def _app_state(a):\n", + " cs = getattr(a, \"compute_status\", None)\n", + " s = getattr(cs, \"state\", None) if cs is not None else None\n", + " if s is None:\n", + " s = getattr(a, \"state\", None)\n", + " return getattr(s, \"value\", str(s)) if s is not None else \"\"\n", + "\n", "\n", "for attempt in range(1, MAX_POLLS + 1):\n", " try:\n", - " ep = w.serving_endpoints.get(ENDPOINT_NAME)\n", - " ready = str(getattr(ep.state, \"ready\", \"\")).upper() if ep.state else \"\"\n", - " cfg_update = str(getattr(ep.state, \"config_update\", \"\")).upper() if ep.state else \"\"\n", - " if \"READY\" in ready:\n", - " print(f\" ✅ Endpoint READY (config_update={cfg_update or 'n/a'})\")\n", + " app = w.apps.get(APP_NAME)\n", + " state = _app_state(app)\n", + " url = getattr(app, \"url\", \"\")\n", + " if state in (\"ACTIVE\", \"RUNNING\", \"READY\") and url:\n", + " print(f\" App ready: state={state}, url={url}\")\n", " break\n", - " print(f\" ⏳ [{attempt}/{MAX_POLLS}] ready={ready}, config_update={cfg_update}\")\n", + " print(f\" [{attempt}/{MAX_POLLS}] state={state or 'unknown'}, url={url or 'pending'}\")\n", " except Exception as e:\n", - " print(f\" ⚠️ poll error: {type(e).__name__}: {e}\")\n", + " print(f\" poll error: {type(e).__name__}: {e}\")\n", " 
time.sleep(POLL_INTERVAL_S)\n", "else:\n", " raise RuntimeError(\n", - " f\"Endpoint {ENDPOINT_NAME} did not become READY within \"\n", + " f\"App {APP_NAME} did not become ready within \"\n", " f\"{MAX_POLLS * POLL_INTERVAL_S // 60} minutes.\"\n", - " )" + " )\n" ], "execution_count": null, "outputs": [] @@ -113,18 +125,15 @@ "source": [ "import os\n", "import random\n", - "from mlflow.deployments import get_deploy_client\n", + "import time\n", + "\n", + "sys.path.append(os.path.abspath(\"../utils\"))\n", + "from agent_app_client import call_agent_app_text\n", "\n", "os.environ[\"MLFLOW_GENAI_EVAL_MAX_WORKERS\"] = \"1\"\n", "os.environ[\"MLFLOW_GENAI_EVAL_MAX_SCORER_WORKERS\"] = \"1\"\n", - "# Bump client-side HTTP timeouts. Defaults (120s) trip up agent endpoints and\n", - "# the LLM judge calls inside scorers (Safety / RelevanceToQuery / Guidelines).\n", - "os.environ.setdefault(\"MLFLOW_DEPLOYMENT_PREDICT_TIMEOUT\", \"600\")\n", - "os.environ.setdefault(\"MLFLOW_DEPLOYMENT_PREDICT_TOTAL_TIMEOUT\", \"900\")\n", "os.environ.setdefault(\"MLFLOW_HTTP_REQUEST_TIMEOUT\", \"600\")\n", "\n", - "_deploy_client = get_deploy_client(\"databricks\")\n", - "\n", "\n", "def _is_rate_limit(exc: BaseException) -> bool:\n", " msg = str(exc)\n", @@ -135,24 +144,33 @@ " )\n", "\n", "\n", + "def _input_messages(value):\n", + " if isinstance(value, dict):\n", + " return value.get(\"messages\") or value.get(\"input\") or [value]\n", + " return value\n", + "\n", + "\n", "def predict_fn(messages):\n", - " \"\"\"Call the deployed Refund Agent endpoint with retry-on-rate-limit.\"\"\"\n", + " \"\"\"Call the deployed Refund Agent App with retry-on-rate-limit.\"\"\"\n", " max_attempts = 6\n", + " input_messages = _input_messages(messages)\n", " for attempt in range(max_attempts):\n", " try:\n", - " return _deploy_client.predict(\n", - " endpoint=ENDPOINT_NAME,\n", - " inputs={\"messages\": messages},\n", + " return call_agent_app_text(\n", + " app_name=APP_NAME,\n", + " 
input_messages=input_messages,\n", + " dbutils=dbutils,\n", + " timeout=600,\n", " )\n", " except Exception as exc:\n", " if not _is_rate_limit(exc) or attempt == max_attempts - 1:\n", " raise\n", " backoff = min(60, 2 ** attempt) + random.uniform(0, 1)\n", - " print(f\" ⚠️ rate-limited (attempt {attempt + 1}/{max_attempts}), sleeping {backoff:.1f}s\")\n", + " print(f\" rate-limited (attempt {attempt + 1}/{max_attempts}), sleeping {backoff:.1f}s\")\n", " time.sleep(backoff)\n", "\n", "\n", - "print(f\"✅ predict_fn defined — will call endpoint {ENDPOINT_NAME}\")" + "print(f\"predict_fn defined \u2014 will call app {APP_NAME}\")\n" ], "execution_count": null, "outputs": [] @@ -231,7 +249,7 @@ "except ImportError:\n", " RelevanceToQuery = None\n", " _has_relevance = False\n", - " print(\"⚠\\ufe0f RelevanceToQuery not available in this MLflow version — skipping it from the eval mix\")\n", + " print(\"\u26a0\\ufe0f RelevanceToQuery not available in this MLflow version \u2014 skipping it from the eval mix\")\n", "\n", "refund_reason = Guidelines(\n", " name=\"refund_reason\",\n", @@ -278,7 +296,7 @@ "if _has_relevance:\n", " _DATASET_SCORERS.append(RelevanceToQuery())\n", "\n", - "print(f\"✅ Scorers defined ({len(_DATASET_SCORERS)}): \" + \", \".join(s.name if hasattr(s, 'name') else type(s).__name__ for s in _DATASET_SCORERS))" + "print(f\"\u2705 Scorers defined ({len(_DATASET_SCORERS)}): \" + \", \".join(s.name if hasattr(s, 'name') else type(s).__name__ for s in _DATASET_SCORERS))" ], "execution_count": null, "outputs": [] @@ -305,7 +323,7 @@ " predict_fn=predict_fn,\n", " )\n", "\n", - "print(\"\\n✅ Evaluation complete\")\n", + "print(\"\\n\u2705 Evaluation complete\")\n", "print(f\" metrics: {getattr(results, 'metrics', 'N/A')}\")" ], "execution_count": null, @@ -315,14 +333,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Trace-derived eval — recent production calls\n", + "## Trace-derived eval \u2014 recent production calls\n", "\n", - "The block 
above runs a curated dataset against the deployed endpoint. That's\n", + "The block above runs a curated dataset against the deployed app. That's\n", "useful for catching regressions on **known-good scenarios**, but it doesn't\n", "tell us how the agent is doing on **real production traffic**.\n", "\n", "This section harvests the most recent successful traces from the prod\n", - "experiment (`/Shared/{CATALOG}_refund_agent_prod` — populated whenever the\n", + "experiment (`/Shared/{CATALOG}_refund_agent_prod` \u2014 populated whenever the\n", "streaming pipeline triggers a refund decision), turns them back into eval\n", "inputs, and re-runs the agent against them with a global quality\n", "guideline.\n", @@ -331,7 +349,7 @@ "for the Operational Supervisor, so all three agents share the same flywheel:\n", "\n", "```\n", - "Production traffic → MLflow traces → resampled eval dataset → quality scoring\n", + "Production traffic \u2192 MLflow traces \u2192 resampled eval dataset \u2192 quality scoring\n", "```\n", "\n", "Skipped gracefully if the prod experiment doesn't exist yet or has no\n", @@ -357,13 +375,13 @@ "\n", "if _exp is None:\n", " print(\n", - " f\"Skipping trace eval — experiment {prod_experiment_name} not found.\\n\"\n", - " \" Send some refund decisions through the deployed endpoint first.\"\n", + " f\"Skipping trace eval \u2014 experiment {prod_experiment_name} not found.\\n\"\n", + " \" Send some refund decisions through the deployed app first.\"\n", " )\n", "else:\n", " print(f\"Found prod traces experiment: {prod_experiment_name} ({_exp.experiment_id})\")\n", "\n", - " # `return_type=\"list\"` is required — `mlflow.search_traces()` returns a\n", + " # `return_type=\"list\"` is required \u2014 `mlflow.search_traces()` returns a\n", " # pandas DataFrame by default in MLflow 3, and iterating it yields column\n", " # names (strings). 
The iteration below expects `Trace` objects with\n", " # `.data.spans` / `.info.request_id`, so we explicitly request the list form.\n", @@ -381,9 +399,9 @@ " try:\n", " root = trace.data.spans[0]\n", " inputs_raw = root.inputs or {}\n", - " messages = inputs_raw.get(\"messages\", [])\n", + " items = inputs_raw.get(\"input\") or inputs_raw.get(\"messages\") or []\n", " question = next(\n", - " (m[\"content\"] for m in messages if m.get(\"role\") == \"user\"),\n", + " (m[\"content\"] for m in items if m.get(\"role\") == \"user\"),\n", " None,\n", " )\n", " if question:\n", @@ -394,7 +412,7 @@ " \"question\": question,\n", " })\n", " except Exception as _te:\n", - " print(f\" ⚠️ skipped trace {getattr(trace.info, 'request_id', '?')}: {_te}\")\n", + " print(f\" \u26a0\ufe0f skipped trace {getattr(trace.info, 'request_id', '?')}: {_te}\")\n", " continue\n", "\n", " traces_df = pd.DataFrame(records)\n", @@ -403,7 +421,7 @@ " if not traces_df.empty:\n", " TRACE_EVAL_LIMIT = 20\n", " trace_eval_data = [\n", - " {\"inputs\": {\"messages\": [{\"role\": \"user\", \"content\": row[\"question\"]}]}}\n", + " {\"inputs\": {\"input\": [{\"role\": \"user\", \"content\": row[\"question\"]}]}}\n", " for _, row in traces_df.head(TRACE_EVAL_LIMIT).iterrows()\n", " ]\n", " print(f\"Trace-derived eval dataset: {len(trace_eval_data)} records (cap={TRACE_EVAL_LIMIT}).\")\n", @@ -427,7 +445,7 @@ "\n", "if trace_eval_data:\n", " _trace_df = pd.DataFrame(\n", - " [{\"query\": row[\"inputs\"][\"messages\"][0][\"content\"]} for row in trace_eval_data]\n", + " [{\"query\": row[\"inputs\"].get(\"input\", row[\"inputs\"].get(\"messages\"))[0][\"content\"]} for row in trace_eval_data]\n", " )\n", " _trace_dataset = mlflow.data.from_pandas(\n", " _trace_df,\n", @@ -451,10 +469,10 @@ " ],\n", " predict_fn=predict_fn,\n", " )\n", - " print(\"\\n✅ Trace-derived evaluation complete\")\n", + " print(\"\\n\u2705 Trace-derived evaluation complete\")\n", " print(f\" metrics: {getattr(trace_results, 
'metrics', 'N/A')}\")\n", "else:\n", - " print(\"Skipping trace evaluation — no trace data available yet.\")" + " print(\"Skipping trace evaluation \u2014 no trace data available yet.\")" ], "execution_count": null, "outputs": [] diff --git a/stages/refunder_agent.ipynb b/stages/refunder_agent.ipynb index 70e7575..09ef7d0 100644 --- a/stages/refunder_agent.ipynb +++ b/stages/refunder_agent.ipynb @@ -245,13 +245,11 @@ "outputs": [], "source": [ "%sql\n", - "-- USE CATALOG is needed in addition to USE SCHEMA + EXECUTE so the serving\n", - "-- endpoint's auto-generated SP can traverse the catalog to reach the UC\n", - "-- functions at model-load time. Normally granted by the root data stage\n", - "-- (canonical_data/raw_data), but repeated here so this stage is self-\n", - "-- sufficient if run standalone or against a pre-existing catalog.\n", + "-- USE CATALOG is needed in addition to USE SCHEMA + EXECUTE so the\n", + "-- Databricks App service principal can traverse the catalog to reach the\n", + "-- UC functions at inference time.\n", "GRANT USE CATALOG ON CATALOG ${CATALOG} TO `account users`;\n", - "GRANT USE SCHEMA ON SCHEMA ${CATALOG}.ai TO `account users`;" + "GRANT USE SCHEMA ON SCHEMA ${CATALOG}.ai TO `account users`;\n" ] }, { @@ -261,10 +259,10 @@ "outputs": [], "source": [ "%sql\n", - "-- Grant EXECUTE so the serving endpoint SP can call these tools at inference time.\n", + "-- Grant EXECUTE so app callers and the agent app SP can call these tools at inference time.\n", "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_order_details TO `account users`;\n", "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_order_delivery_time TO `account users`;\n", - "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_location_timings TO `account users`;" + "GRANT EXECUTE ON FUNCTION ${CATALOG}.ai.get_location_timings TO `account users`;\n" ] }, { @@ -280,7 +278,12 @@ } }, "source": [ - "#### Model" + "#### App Agent\n", + "\n", + "- Install orchestration dependencies and restart Python 
for a clean runtime.\n", + "- Capture widget inputs (`CATALOG`, `LLM_MODEL`) and resolve the deterministic Databricks App name.\n", + "- Use `../apps/refund-agent` as the source of truth for the LangGraph refund workflow.\n", + "- Treat `LLM_MODEL` as a Unity AI Gateway endpoint name; no custom-agent LLM calls use legacy model-serving invocation routes.\n" ] }, { @@ -301,8 +304,8 @@ }, "outputs": [], "source": [ - "%pip install -U -qqqq mlflow-skinny[databricks] \"langgraph>=0.3.5,<0.4.0\" databricks-langchain databricks-agents uv\n", - "dbutils.library.restartPython()" + "%pip install -U -qqqq mlflow[databricks] databricks-sdk requests openai\n", + "dbutils.library.restartPython()\n" ] }, { @@ -324,7 +327,15 @@ "outputs": [], "source": [ "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "LLM_MODEL = dbutils.widgets.get(\"LLM_MODEL\")" + "LLM_MODEL = dbutils.widgets.get(\"LLM_MODEL\")\n", + "\n", + "import sys\n", + "sys.path.append('../utils')\n", + "from agent_app_client import refund_agent_app_name\n", + "\n", + "APP_NAME = refund_agent_app_name(CATALOG)\n", + "UC_MODEL_NAME = f\"{CATALOG}.ai.refund_agent_app\"\n", + "print(f\"Refund agent app: {APP_NAME}\")\n" ] }, { @@ -345,7 +356,7 @@ "# set_experiment creates the experiment if it doesn't exist, or activates it if it does.\n", "dev_experiment = mlflow.set_experiment(dev_experiment_name)\n", "dev_experiment_id = dev_experiment.experiment_id\n", - "print(f\"✅ Using dev experiment: {dev_experiment_name} (ID: {dev_experiment_id})\")\n", + "print(f\"\u2705 Using dev experiment: {dev_experiment_name} (ID: {dev_experiment_id})\")\n", "\n", "# Track the experiment in uc_state so `databricks bundle run cleanup` deletes it.\n", "import sys\n", @@ -357,175 +368,79 @@ " \"name\": dev_experiment_name,\n", "}\n", "add(CATALOG, \"experiments\", experiment_data)\n", - "print(f\"✅ Added dev experiment to UC state\")" + "print(f\"\u2705 Added dev experiment to UC state\")" ] }, { - "cell_type": "code", - "execution_count": 0, 
- "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "e1e226ce-f3ee-4522-82fd-8bdf7121f999", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], + "cell_type": "markdown", + "metadata": {}, "source": [ - "import re\n", - "import os\n", - "from IPython.core.magic import register_cell_magic\n", - "\n", - "# Records absolute paths of files written by `%%writefilev`, keyed by the\n", - "# filename argument. Cell 17 below imports `from agent import LLM_ENDPOINT_NAME`.\n", - "#\n", - "# Earlier versions wrote to `os.path.abspath(filename)` (i.e. CWD), but CWD\n", - "# on serverless notebooks is usually a `/Workspace/Users/...` path, and the\n", - "# workspace-files-as-Python-modules feature is not reliably wired through\n", - "# on serverless — the file ends up on disk, `os.path.exists()` returns True,\n", - "# but `import agent` still raises `ModuleNotFoundError`. 
Pinning the write\n", - "# dir to a regular local-disk path (`/local_disk0/tmp/...`, falling back to\n", - "# `/tmp/...`) sidesteps the workspace-files importer entirely.\n", - "_WRITEFILEV_ABS_PATHS = {}\n", - "\n", - "_WRITEFILEV_DIR = \"/local_disk0/tmp/caspers_writefilev\"\n", - "if not os.path.isdir(\"/local_disk0\"):\n", - " _WRITEFILEV_DIR = \"/tmp/caspers_writefilev\"\n", - "os.makedirs(_WRITEFILEV_DIR, exist_ok=True)\n", - "\n", - "@register_cell_magic\n", - "def writefilev(line, cell):\n", - " \"\"\"\n", - " %%writefilev file.py\n", - " Allows {{var}} substitutions while leaving normal {} intact.\n", - "\n", - " Writes to a stable local-disk path (NOT CWD) so subsequent\n", - " `from import ...` always succeeds, even on serverless\n", - " where CWD is a /Workspace path.\n", - " \"\"\"\n", - " filename = line.strip()\n", - "\n", - " def replacer(match):\n", - " expr = match.group(1)\n", - " return str(eval(expr, globals(), locals()))\n", - "\n", - " content = re.sub(r\"\\{\\{(.*?)\\}\\}\", replacer, cell)\n", + "#### Production Experiment\n", "\n", - " abs_path = os.path.join(_WRITEFILEV_DIR, filename)\n", - " with open(abs_path, \"w\") as f:\n", - " f.write(content)\n", - " _WRITEFILEV_ABS_PATHS[filename] = abs_path\n", - " print(f\"Wrote file with substitutions: {abs_path}\")" + "Create the production MLflow experiment used by the Databricks App runtime for traces." 
] }, { "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "63ddd9df-5f1c-42ea-9446-196558fd5b2c", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, + "execution_count": null, + "metadata": {}, "outputs": [], "source": [ - "%%writefilev agent.py\n", - "from typing import Any, Generator, Literal, Optional, Sequence, Union\n", - "\n", "import mlflow\n", - "from databricks_langchain import (\n", - " ChatDatabricks,\n", - " VectorSearchRetrieverTool,\n", - ")\n", - "from langchain_core.tools import tool\n", - "from unitycatalog.ai.core.base import get_uc_function_client\n", - "from langchain_core.language_models import LanguageModelLike\n", - "from langchain_core.runnables import RunnableConfig, RunnableLambda\n", - "from langchain_core.tools import BaseTool\n", - "from langgraph.graph import END, StateGraph\n", - "try:\n", - " from langgraph.graph.graph import CompiledGraph\n", - " from langgraph.graph.state import CompiledStateGraph\n", - "except ImportError: # langgraph >=0.4 restructured these internal modules\n", - " from typing import Any\n", - " CompiledGraph = Any\n", - " CompiledStateGraph = Any\n", - "from langgraph.prebuilt.tool_node import ToolNode\n", - "from mlflow.langchain.chat_agent_langgraph import ChatAgentState, ChatAgentToolNode\n", - "from mlflow.pyfunc import ChatAgent\n", - "from mlflow.types.agent import (\n", - " ChatAgentChunk,\n", - " ChatAgentMessage,\n", - " ChatAgentResponse,\n", - " ChatContext,\n", - ")\n", - "\n", - "import json as _json\n", - "import uuid as _uuid\n", - "from pydantic import BaseModel, ValidationError\n", - "\n", - "mlflow.langchain.autolog()\n", - "\n", - "# Catalog is substituted at notebook write time so the agent module is\n", - "# self-contained and does not depend on widgets at request time.\n", - "CATALOG = \"{{CATALOG}}\"\n", - "\n", 
- "# Lazy UC function client — constructed on first tool call, NOT at\n", - "# module import / model load time. This is the same pattern used by\n", - "# complaint_agent.py. It matters because the serving endpoint's\n", - "# auto-created service principal is added to `account users` (which\n", - "# holds USE CATALOG / USE SCHEMA / EXECUTE) with a multi-minute\n", - "# propagation delay; if we introspect or execute UC functions at model\n", - "# LOAD time (the way the old UCFunctionToolkit code did), the first\n", - "# deploy fails before propagation completes. Deferring all UC access to\n", - "# request time lets model load complete unconditionally, and by the time\n", - "# real traffic arrives the SP has propagated.\n", - "_uc_client = None\n", "\n", + "prod_experiment_name = f\"/Shared/{CATALOG}_refund_agent_prod\"\n", + "prod_experiment = mlflow.set_experiment(prod_experiment_name)\n", + "prod_experiment_id = prod_experiment.experiment_id\n", + "print(f\"Using prod experiment: {prod_experiment_name} (ID: {prod_experiment_id})\")\n", "\n", - "def _client():\n", - " global _uc_client\n", - " if _uc_client is None:\n", - " _uc_client = get_uc_function_client()\n", - " return _uc_client\n", + "import sys\n", + "sys.path.append('../utils')\n", + "from uc_state import add\n", "\n", + "add(CATALOG, \"experiments\", {\n", + " \"experiment_id\": prod_experiment_id,\n", + " \"name\": prod_experiment_name,\n", + "})\n", + "print(\"Added prod experiment to UC state\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Prompt Registry\n", "\n", - "class RefundDecision(BaseModel):\n", - " refund_usd: float = 0.0\n", - " refund_class: Literal[\"none\", \"partial\", \"full\"] = \"none\"\n", - " reason: str = \"\"\n", + "Seed the prompt registry from the Databricks App source so prompt governance stays with the deployed app code." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import re\n", + "import sys\n", "\n", + "sys.path.append('../utils')\n", + "from prompt_registry import seed_prompt_history\n", "\n", - "############################################\n", - "# Define your LLM endpoint and system prompt\n", - "############################################\n", - "LLM_ENDPOINT_NAME = f\"{{LLM_MODEL}}\"\n", - "llm = ChatDatabricks(endpoint=LLM_ENDPOINT_NAME)\n", + "_agent_py_path = os.path.abspath(\"../apps/refund-agent/agent.py\")\n", + "with open(_agent_py_path) as f:\n", + " _agent_py = f.read()\n", "\n", - "# At endpoint startup, try to load the prompt from the MLflow Prompt\n", - "# Registry so that editing the prompt and bumping the `production` alias\n", - "# takes effect on the next endpoint replica restart — without rebuilding\n", - "# the UC model. Falls back to the literal _FALLBACK_PROMPT below if the\n", - "# registry is empty (first-ever deploy) or unreachable (auth, etc.).\n", - "# The literal is the source of truth: the Prompt Registry cell at the\n", - "# bottom of stages/refunder_agent.ipynb extracts it via regex and\n", - "# registers it as a new version on every deploy.\n", - "PROMPT_URI = f\"prompts:/{CATALOG}.prompts.refund_system@production\"\n", + "_match = re.search(r'_FALLBACK_PROMPT\\s*=\\s*\"\"\"(.*?)\"\"\"', _agent_py, re.DOTALL)\n", + "if not _match:\n", + " raise RuntimeError(\"Could not extract _FALLBACK_PROMPT from refund app source\")\n", "\n", - "_FALLBACK_PROMPT = \"\"\"You are RefundGPT, a CX agent responsible for refund decisions on food delivery orders.\n", + "_REFUND_V1 = (\n", + " \"You are a refund agent for a food delivery service. \"\n", + " \"Given an order_id, decide whether to issue a refund and how much. 
\"\n", + " \"Return a single-line JSON with `refund_usd` (float), `refund_class` \"\n", + " \"(\\\"none\\\" | \\\"partial\\\" | \\\"full\\\"), and `reason` (short explanation).\"\n", + ")\n", + "_REFUND_V2 = \"\"\"You are RefundGPT, a CX agent responsible for refund decisions on food delivery orders.\n", "\n", " You can call tools to gather the information you need. Start with an `order_id`.\n", "\n", @@ -536,478 +451,56 @@ " 4. Call `get_location_timings(location)` to get the P50/P75/P99 values.\n", " 5. Compare actual delivery time to those percentiles.\n", "\n", - " Refund policy:\n", - "\n", - " A) SLA-based refund (primary path):\n", + " Refund policy (SLA-based):\n", " - If the order arrived AFTER the P75 delivery time: recommend a `partial` or `full` refund based on how late.\n", - " - If the order arrived BEFORE the P75: no SLA-based refund.\n", - "\n", - " B) Goodwill credit (only when complaint context is provided in the user message):\n", - " The user may include lines such as:\n", - " Customer complaint: \"\"\n", - " Complaint category: \n", - " Complaint agent suggested credit: $\n", - " When all three are present AND the SLA path returns \"none\", you MAY ratify the\n", - " complaint agent's goodwill credit:\n", - " - Set `refund_class` = \"partial\"\n", - " - Set `refund_usd` to the suggested credit amount (capped at $10)\n", - " - In `reason`, note that the order was on time per SLA but a goodwill credit\n", - " is being issued in response to the customer's complaint (cite the category).\n", - " Only ratify when the suggested credit is plausible (>$0 and ≤$10) and the\n", - " complaint category is non-empty. 
Otherwise return \"none\" with an SLA-based reason.\n", - "\n", - " When NO complaint context is provided, behave exactly as the SLA-based path (A) —\n", - " do not invent goodwill credits.\n", + " - If the order arrived BEFORE the P75: no refund.\n", "\n", " Output a single-line JSON with these fields:\n", " - `refund_usd` (float),\n", - " - `refund_class` (\"none\" | \"partial\" | \"full\"),\n", - " - `reason` (short human explanation. If goodwill, say so explicitly.)\n", + " - `refund_class` (\\\"none\\\" | \\\"partial\\\" | \\\"full\\\"),\n", + " - `reason` (short human explanation).\n", "\n", " You must return only the JSON. No extra text or markdown.\"\"\"\n", "\n", - "# Required at endpoint startup: prompt registry URI defaults vary across\n", - "# Model Serving runtimes; without this the 3-part UC name is treated as\n", - "# an opaque string in workspace MLflow and load_prompt raises NotFound.\n", - "mlflow.set_registry_uri(\"databricks-uc\")\n", - "\n", - "try:\n", - " system_prompt = mlflow.genai.load_prompt(PROMPT_URI).template\n", - " print(f\"[refund-agent] Loaded system prompt from {PROMPT_URI}\")\n", - "except Exception as _exc:\n", - " print(\n", - " f\"[refund-agent] Failed to load {PROMPT_URI} \"\n", - " f\"({type(_exc).__name__}: {_exc}); using _FALLBACK_PROMPT.\"\n", - " )\n", - " system_prompt = _FALLBACK_PROMPT\n", - "\n", - "###############################################################################\n", - "## Define tools for your agent, enabling it to retrieve data or take actions\n", - "## beyond text generation.\n", - "##\n", - "## We use plain @tool-decorated wrappers (instead of UCFunctionToolkit) so\n", - "## the UC functions are only introspected/executed at REQUEST time, not at\n", - "## model LOAD time. 
This is the same lazy pattern used by complaint_agent\n", - "## and support_request_agent, and is required for first-attempt success of\n", - "## agents.deploy() — see _client() above for the full rationale.\n", - "###############################################################################\n", - "\n", - "\n", - "@tool\n", - "def get_order_details(order_id: str) -> str:\n", - " \"\"\"Get the full event history for an order (creation, accepted, dispatched,\n", - " delivered, etc). Use this first to verify the order id is valid and to\n", - " confirm the order was delivered.\"\"\"\n", - " return str(_client().execute_function(\n", - " f\"{CATALOG}.ai.get_order_details\", {\"oid\": order_id}\n", - " ).value)\n", - "\n", - "\n", - "@tool\n", - "def get_order_delivery_time(order_id: str) -> str:\n", - " \"\"\"Return the creation timestamp, delivered timestamp, and elapsed delivery\n", - " duration for an order. Use this to compute the actual delivery time.\"\"\"\n", - " return str(_client().execute_function(\n", - " f\"{CATALOG}.ai.get_order_delivery_time\", {\"oid\": order_id}\n", - " ).value)\n", - "\n", - "\n", - "@tool\n", - "def get_location_timings(location: str) -> str:\n", - " \"\"\"Return the P50/P75/P99 delivery time percentiles for a kitchen location\n", - " so the actual delivery time can be compared to the SLA bands.\"\"\"\n", - " return str(_client().execute_function(\n", - " f\"{CATALOG}.ai.get_location_timings\", {\"loc\": location}\n", - " ).value)\n", - "\n", - "\n", - "tools = [get_order_details, get_order_delivery_time, get_location_timings]\n", - "\n", - "#####################\n", - "## Define agent logic\n", - "#####################\n", - "\n", - "\n", - "def create_tool_calling_agent(\n", - " model: LanguageModelLike,\n", - " tools: Union[Sequence[BaseTool], ToolNode],\n", - " system_prompt: Optional[str] = None,\n", - ") -> CompiledGraph:\n", - " model = model.bind_tools(tools)\n", - "\n", - " # Define the function that determines which node 
to go to\n", - " def should_continue(state: ChatAgentState):\n", - " messages = state[\"messages\"]\n", - " last_message = messages[-1]\n", - " # If there are function calls, continue. else, end\n", - " if last_message.get(\"tool_calls\"):\n", - " return \"continue\"\n", - " else:\n", - " return \"end\"\n", - "\n", - " if system_prompt:\n", - " preprocessor = RunnableLambda(\n", - " lambda state: [{\"role\": \"system\", \"content\": system_prompt}]\n", - " + state[\"messages\"]\n", - " )\n", - " else:\n", - " preprocessor = RunnableLambda(lambda state: state[\"messages\"])\n", - " model_runnable = preprocessor | model\n", - "\n", - " def call_model(\n", - " state: ChatAgentState,\n", - " config: RunnableConfig,\n", - " ):\n", - " response = model_runnable.invoke(state, config)\n", - "\n", - " return {\"messages\": [response]}\n", - "\n", - " workflow = StateGraph(ChatAgentState)\n", - "\n", - " workflow.add_node(\"agent\", RunnableLambda(call_model))\n", - " workflow.add_node(\"tools\", ChatAgentToolNode(tools))\n", + "_common_tags = {\n", + " \"agent\": \"refund\",\n", + " \"stage\": \"refunder_agent\",\n", + " \"app_name\": APP_NAME,\n", + " \"uc_model\": UC_MODEL_NAME,\n", + " \"consumed_via\": \"mlflow.genai.load_prompt at Databricks App startup\",\n", + "}\n", "\n", - " workflow.set_entry_point(\"agent\")\n", - " workflow.add_conditional_edges(\n", - " \"agent\",\n", - " should_continue,\n", + "seed_prompt_history(\n", + " spark=spark,\n", + " catalog=CATALOG,\n", + " name=\"refund_system\",\n", + " historical=[\n", " {\n", - " \"continue\": \"tools\",\n", - " \"end\": END,\n", + " \"template\": _REFUND_V1,\n", + " \"commit_message\": \"v1: bare-bones refund decisioner, no SLA logic or tool use (demo history seed)\",\n", + " \"tags\": _common_tags,\n", " },\n", - " )\n", - " workflow.add_edge(\"tools\", \"agent\")\n", - "\n", - " return workflow.compile()\n", - "\n", - "\n", - "class LangGraphChatAgent(ChatAgent):\n", - " def __init__(self, agent: 
CompiledStateGraph):\n", - " self.agent = agent\n", - "\n", - " def predict(\n", - " self,\n", - " messages: list[ChatAgentMessage],\n", - " context: Optional[ChatContext] = None,\n", - " custom_inputs: Optional[dict[str, Any]] = None,\n", - " ) -> ChatAgentResponse:\n", - " request = {\"messages\": self._convert_messages_to_dict(messages)}\n", - "\n", - " result_messages = []\n", - " for event in self.agent.stream(request, stream_mode=\"updates\"):\n", - " for node_data in event.values():\n", - " result_messages.extend(\n", - " ChatAgentMessage(**msg) for msg in node_data.get(\"messages\", [])\n", - " )\n", - " for i in range(len(result_messages) - 1, -1, -1):\n", - " msg = result_messages[i]\n", - " role = msg.role if hasattr(msg, \"role\") else (msg.get(\"role\") if isinstance(msg, dict) else None)\n", - " content = msg.content if hasattr(msg, \"content\") else (msg.get(\"content\", \"\") if isinstance(msg, dict) else \"\")\n", - " if role == \"assistant\" and content:\n", - " try:\n", - " parsed = RefundDecision.model_validate_json(content)\n", - " orig_id = getattr(msg, 'id', None) or str(_uuid.uuid4())\n", - " result_messages[i] = ChatAgentMessage(id=orig_id, role=\"assistant\", content=parsed.model_dump_json())\n", - " except (ValidationError, Exception):\n", - " pass\n", - " break\n", - " return ChatAgentResponse(messages=result_messages)\n", - "\n", - " def predict_stream(\n", - " self,\n", - " messages: list[ChatAgentMessage],\n", - " context: Optional[ChatContext] = None,\n", - " custom_inputs: Optional[dict[str, Any]] = None,\n", - " ) -> Generator[ChatAgentChunk, None, None]:\n", - " request = {\"messages\": self._convert_messages_to_dict(messages)}\n", - " for event in self.agent.stream(request, stream_mode=\"updates\"):\n", - " for node_data in event.values():\n", - " yield from (\n", - " ChatAgentChunk(**{\"delta\": msg}) for msg in node_data[\"messages\"]\n", - " )\n", - "\n", - "\n", - "# Create the agent object, and specify it as the agent object 
to use when\n", - "# loading the agent back for inference via mlflow.models.set_model()\n", - "agent = create_tool_calling_agent(llm, tools, system_prompt)\n", - "AGENT = LangGraphChatAgent(agent)\n", - "mlflow.models.set_model(AGENT)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "6a909c3e-4fd0-4cdd-98af-93a7ccb3badb", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], - "source": [ - "import time\n", - "\n", - "sample_order_id = None\n", - "for attempt in range(12):\n", - " rows = spark.sql(f\"\"\"\n", - " SELECT order_id \n", - " FROM {CATALOG}.lakeflow.all_events \n", - " WHERE event_type='delivered'\n", - " LIMIT 1\n", - " \"\"\").collect()\n", - " if rows:\n", - " sample_order_id = rows[0]['order_id']\n", - " break\n", - " print(f\"No delivered events yet (attempt {attempt+1}/12). Waiting 30s for pipeline data...\")\n", - " time.sleep(30)\n", - "\n", - "if not sample_order_id:\n", - " raise RuntimeError(\n", - " f\"No delivered events found in {CATALOG}.lakeflow.all_events after 6 minutes. 
\"\n", - " \"Ensure the Canonical_Data and Lakeflow pipeline stages completed and processed data.\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "0c672421-c8d0-4252-92d1-576dd01f0f72", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], - "source": [ - "assert sample_order_id is not None\n", - "print(sample_order_id)" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "38ee737e-0495-4856-a87a-60bac7baf043", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], - "source": [ - "import mlflow\n", - "import sys\n", - "import os\n", - "\n", - "# Use the absolute path captured by `%%writefilev` (see cell 13). 
This is\n", - "# stable across CWD drift / kernel restarts; the older `sys.path.append(os.getcwd())`\n", - "# pattern was failing with `ModuleNotFoundError: No module named 'agent'` when\n", - "# CWD changed between the writefile cell and this one.\n", - "_agent_py_path = _WRITEFILEV_ABS_PATHS.get(\"agent.py\")\n", - "if _agent_py_path and os.path.exists(_agent_py_path):\n", - " _agent_dir = os.path.dirname(_agent_py_path)\n", - " print(f\"Importing agent from: {_agent_py_path}\")\n", - "else:\n", - " # Defensive fallback: try the same locations the old code did, plus\n", - " # /databricks/driver which is the typical CWD on classic clusters,\n", - " # and the new pinned local-disk dir used by `%%writefilev`.\n", - " _candidate_dirs = [\n", - " os.getcwd(),\n", - " \"/databricks/driver\",\n", - " \"/local_disk0/tmp/caspers_writefilev\",\n", - " \"/local_disk0/tmp\",\n", - " \"/tmp/caspers_writefilev\",\n", - " \"/tmp\",\n", - " ]\n", - " try:\n", - " _nb_path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()\n", - " _candidate_dirs.append(os.path.dirname(_nb_path))\n", - " except Exception:\n", - " pass\n", - " _agent_dir = next((d for d in _candidate_dirs if os.path.exists(os.path.join(d, \"agent.py\"))), None)\n", - " if _agent_dir is None:\n", - " raise FileNotFoundError(\n", - " f\"agent.py not found. _WRITEFILEV_ABS_PATHS={_WRITEFILEV_ABS_PATHS}, \"\n", - " f\"CWD={os.getcwd()}, candidates tried: {_candidate_dirs}\"\n", - " )\n", - " _agent_py_path = os.path.join(_agent_dir, \"agent.py\")\n", - " print(f\"Importing agent from fallback dir: {_agent_py_path}\")\n", - "\n", - "if _agent_dir not in sys.path:\n", - " sys.path.insert(0, _agent_dir)\n", - "\n", - "# Invalidate any cached `agent` module that may be left over from a prior\n", - "# attempt on the same warehouse (e.g. 
after a transient failure) — without\n", - "# this, a partially-imported / stale entry in `sys.modules` can cause\n", - "# subsequent `from agent import ...` to fail or return stale symbols.\n", - "sys.modules.pop(\"agent\", None)\n", - "\n", - "from agent import LLM_ENDPOINT_NAME\n", - "from mlflow.models.resources import DatabricksFunction, DatabricksServingEndpoint\n", - "from pkg_resources import get_distribution\n", - "\n", - "# Same `resources=[...]` shape complaint_agent uses — LLM endpoint + the\n", - "# three UC tool functions. Listing the functions lets agents.deploy()\n", - "# auto-grant EXECUTE on them to the endpoint SP. Combined with the lazy\n", - "# @tool wrappers in agent.py (which defer UC introspection until request\n", - "# time), this lets model LOAD complete without needing any UC permission,\n", - "# so the first deploy succeeds even while the SP's `account users`\n", - "# membership is still propagating.\n", - "resources = [\n", - " DatabricksServingEndpoint(endpoint_name=LLM_ENDPOINT_NAME),\n", - " DatabricksFunction(function_name=f\"{CATALOG}.ai.get_order_details\"),\n", - " DatabricksFunction(function_name=f\"{CATALOG}.ai.get_order_delivery_time\"),\n", - " DatabricksFunction(function_name=f\"{CATALOG}.ai.get_location_timings\"),\n", - "]\n", - "\n", - "input_example = {\n", - " \"messages\": [\n", " {\n", - " \"role\": \"user\",\n", - " \"content\": f\"{sample_order_id}\"\n", - " }\n", - " ]\n", - "}\n", - "\n", - "with mlflow.start_run():\n", - " logged_agent_info = mlflow.pyfunc.log_model(\n", - " name=\"agent_v2\",\n", - " python_model=_agent_py_path,\n", - " input_example=input_example,\n", - " resources=resources,\n", - " pip_requirements=[\n", - " f\"databricks-connect=={get_distribution('databricks-connect').version}\",\n", - " f\"mlflow=={get_distribution('mlflow').version}\",\n", - " f\"databricks-langchain=={get_distribution('databricks-langchain').version}\",\n", - " f\"langgraph=={get_distribution('langgraph').version}\",\n", 
- " ],\n", - " )\n", - "\n", - "mlflow.set_active_model(model_id = logged_agent_info.model_id)" + " \"template\": _REFUND_V2,\n", + " \"commit_message\": \"v2: added tool-calling + SLA-based refund policy (P75 cutoff) (demo history seed)\",\n", + " \"tags\": _common_tags,\n", + " },\n", + " ],\n", + " current={\n", + " \"template\": _match.group(1).strip(),\n", + " \"commit_message\": \"v3 (production): SLA + goodwill credit path, deployed as Databricks App\",\n", + " \"tags\": {**_common_tags, \"deployment_kind\": \"databricks_app\"},\n", + " },\n", + ")\n" ] }, { "cell_type": "markdown", - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, - "inputWidgets": {}, - "nuid": "b2f870d6-05d0-478e-8784-98793ba8480c", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "source": [ - "#### log refunder to `UC`" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "273249df-a966-45fc-9ddf-d7138af804b7", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], + "metadata": {}, "source": [ - "from databricks.sdk import WorkspaceClient\n", - "from databricks.sdk.service.serving import EndpointStateReady\n", - "\n", - "mlflow.set_registry_uri(\"databricks-uc\")\n", - "\n", - "UC_MODEL_NAME = f\"{CATALOG}.ai.refunder\"\n", - "endpoint_name = dbutils.widgets.get(\"REFUND_AGENT_ENDPOINT_NAME\")\n", - "\n", - "\n", - "def _endpoint_already_serving(name: str, uc_model_name: str) -> bool:\n", - " \"\"\"Return True iff a serving endpoint is READY and already serving uc_model_name.\n", - "\n", - " Used to short-circuit register_model + agents.deploy on re-runs of this\n", - " stage when the endpoint from a previous run is still healthy — saves ~15\n", - " minutes of cold container build + serving provisioning. 
To force a fresh\n", - " deploy after editing agent code, delete the endpoint and rerun the stage.\n", - " \"\"\"\n", - " try:\n", - " ep = WorkspaceClient().serving_endpoints.get(name)\n", - " except Exception:\n", - " return False\n", - " if not ep.state or ep.state.ready != EndpointStateReady.READY:\n", - " return False\n", - " cfg = getattr(ep, \"config\", None) or getattr(ep, \"pending_config\", None)\n", - " if not cfg:\n", - " return False\n", - " served = []\n", - " for se in (getattr(cfg, \"served_entities\", None) or []):\n", - " n = getattr(se, \"entity_name\", None)\n", - " if n:\n", - " served.append(n)\n", - " for sm in (getattr(cfg, \"served_models\", None) or []):\n", - " n = getattr(sm, \"model_name\", None)\n", - " if n:\n", - " served.append(n)\n", - " return uc_model_name in served\n", - "\n", - "\n", - "_reuse_endpoint = _endpoint_already_serving(endpoint_name, UC_MODEL_NAME)\n", + "#### Deploy Agent App\n", "\n", - "if _reuse_endpoint:\n", - " print(\n", - " f\"\\u267b\\ufe0f Endpoint {endpoint_name} is already READY and serving {UC_MODEL_NAME}; \"\n", - " f\"skipping register_model + agents.deploy (saves ~15 min). \"\n", - " f\"Delete the endpoint to force a fresh deploy.\"\n", - " )\n", - " uc_registered_model_info = None\n", - "else:\n", - " # register the model to UC\n", - " uc_registered_model_info = mlflow.register_model(\n", - " model_uri=logged_agent_info.model_uri, name=UC_MODEL_NAME\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, - "inputWidgets": {}, - "nuid": "bf715fb3-953c-4516-8d07-79c0d362a99e", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "source": [ - "#### deploy the agent to model serving" + "Create/update the Databricks App, grant the app service principal UC and Gateway access, deploy source, and register the app in uc_state." 
] }, { @@ -1016,193 +509,193 @@ "metadata": {}, "outputs": [], "source": [ - "import mlflow\n", - "\n", - "# Create prod experiment for production inference traces.\n", - "# The serving endpoint logs every request as a trace; pointing\n", - "# MLFLOW_EXPERIMENT_ID at this experiment (in the agents.deploy() call\n", - "# below) ensures those traces land somewhere stable and findable, instead\n", - "# of the auto-created /Serving/{endpoint} experiment. Mirrors the pattern\n", - "# in stages/complaint_agent.ipynb so the runbook can link to a single\n", - "# `/Shared/{CATALOG}_refund_agent_prod` URL across deploys.\n", - "prod_experiment_name = f\"/Shared/{CATALOG}_refund_agent_prod\"\n", - "\n", - "# set_experiment creates the experiment if it doesn't exist, or activates it if it does.\n", - "prod_experiment = mlflow.set_experiment(prod_experiment_name)\n", - "prod_experiment_id = prod_experiment.experiment_id\n", - "print(f\"✅ Using prod experiment: {prod_experiment_name} (ID: {prod_experiment_id})\")\n", - "\n", - "# Track the experiment in uc_state so `databricks bundle run cleanup` deletes it.\n", + "import os\n", "import sys\n", + "import time\n", + "\n", "sys.path.append('../utils')\n", + "from agent_app_client import gateway_chat_probe\n", "from uc_state import add\n", "\n", - "experiment_data = {\n", - " \"experiment_id\": prod_experiment_id,\n", - " \"name\": prod_experiment_name,\n", - "}\n", - "add(CATALOG, \"experiments\", experiment_data)\n", - "print(f\"✅ Added prod experiment to UC state\")" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "11c0d2d9-6379-4e2c-84bf-9004dd74b6dd", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], - "source": [ - "from datetime import timedelta\n", + "from databricks.sdk import WorkspaceClient\n", + "from 
databricks.sdk.service import catalog as catalog_svc\n", + "from databricks.sdk.service.apps import App, AppDeployment\n", + "from databricks.sdk.service.serving import (\n", + " ServingEndpointAccessControlRequest,\n", + " ServingEndpointPermissionLevel,\n", + ")\n", "\n", - "from databricks import agents\n", + "w = WorkspaceClient()\n", + "source_code_path = os.path.abspath(\"../apps/refund-agent\")\n", + "print(f\"App name: {APP_NAME}\")\n", + "print(f\"App source: {source_code_path}\")\n", + "\n", + "gateway_chat_probe(llm_model=LLM_MODEL, w=w, dbutils=dbutils)\n", + "print(f\"Verified {LLM_MODEL} is queryable through Unity AI Gateway\")\n", + "\n", + "app_yaml_path = os.path.join(source_code_path, \"app.yaml\")\n", + "app_yaml_contents = f\"\"\"command:\n", + " - python\n", + " - start_server.py\n", + "env:\n", + " - name: DATABRICKS_CATALOG\n", + " value: '{CATALOG}'\n", + " - name: LLM_MODEL\n", + " value: '{LLM_MODEL}'\n", + " - name: MLFLOW_EXPERIMENT_ID\n", + " value: '{prod_experiment_id}'\n", + " - name: MLFLOW_TRACKING_URI\n", + " value: 'databricks'\n", + " - name: MLFLOW_REGISTRY_URI\n", + " value: 'databricks-uc'\n", + "\"\"\"\n", + "with open(app_yaml_path, \"w\") as f:\n", + " f.write(app_yaml_contents)\n", + "print(f\"Wrote app runtime config: {app_yaml_path}\")\n", + "\n", + "app_def = App(\n", + " name=APP_NAME,\n", + " description=\"Casper's refund decision agent served by MLflow AgentServer on Databricks Apps.\",\n", + " default_source_code_path=source_code_path,\n", + ")\n", + "try:\n", + " w.apps.get(APP_NAME)\n", + " print(f\"App {APP_NAME} exists, updating...\")\n", + " w.apps.update(APP_NAME, app_def)\n", + "except Exception:\n", + " print(f\"Creating app {APP_NAME}...\")\n", + " w.apps.create(app_def)\n", + "\n", + "\n", + "def _app_state(a):\n", + " cs = getattr(a, \"compute_status\", None)\n", + " s = getattr(cs, \"state\", None) if cs is not None else None\n", + " if s is None:\n", + " s = getattr(a, \"state\", None)\n", + " return 
getattr(s, \"value\", str(s)) if s is not None else \"\"\n", + "\n", + "\n", + "deadline = time.time() + 30 * 60\n", + "while True:\n", + " current = w.apps.get(APP_NAME)\n", + " state = _app_state(current)\n", + " print(f\"App {APP_NAME} state: {state}\")\n", + " if state in (\"ACTIVE\", \"RUNNING\", \"READY\"):\n", + " app_status = current\n", + " break\n", + " if state in (\"ERROR\", \"FAILED\"):\n", + " raise RuntimeError(f\"App {APP_NAME} entered failure state: {state}\")\n", + " if time.time() > deadline:\n", + " raise TimeoutError(f\"App {APP_NAME} not ready after 30 minutes (last state: {state})\")\n", + " time.sleep(15)\n", + "\n", + "app_sp_id = (\n", + " getattr(app_status, \"service_principal_client_id\", None)\n", + " or (app_status.as_dict() if hasattr(app_status, \"as_dict\") else {}).get(\"service_principal_client_id\")\n", + ")\n", + "app_uc_principal = (\n", + " getattr(app_status, \"id\", None)\n", + " or app_sp_id\n", + " or (app_status.as_dict() if hasattr(app_status, \"as_dict\") else {}).get(\"id\")\n", + ")\n", + "assert app_sp_id, \"Could not determine app service principal client ID\"\n", + "assert app_uc_principal, \"Could not determine app UC principal\"\n", + "print(f\"App SP ID: {app_sp_id}\")\n", + "\n", + "for full_name, securable_type, privilege in [\n", + " (f\"{CATALOG}\", \"CATALOG\", catalog_svc.Privilege.USE_CATALOG),\n", + " (f\"{CATALOG}.ai\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", + " (f\"{CATALOG}.prompts\", \"SCHEMA\", catalog_svc.Privilege.USE_SCHEMA),\n", + " (f\"{CATALOG}.ai.get_order_details\", \"FUNCTION\", catalog_svc.Privilege.EXECUTE),\n", + " (f\"{CATALOG}.ai.get_order_delivery_time\", \"FUNCTION\", catalog_svc.Privilege.EXECUTE),\n", + " (f\"{CATALOG}.ai.get_location_timings\", \"FUNCTION\", catalog_svc.Privilege.EXECUTE),\n", + "]:\n", + " try:\n", + " w.grants.update(\n", + " full_name=full_name,\n", + " securable_type=securable_type,\n", + " changes=[\n", + " catalog_svc.PermissionsChange(\n", + " 
add=[privilege],\n", + " principal=app_uc_principal,\n", + " )\n", + " ],\n", + " )\n", + " print(f\"Granted {privilege} on {securable_type} {full_name}\")\n", + " except Exception as e:\n", + " print(f\"Could not grant {privilege} on {full_name} to {app_uc_principal}: {e}\")\n", + "\n", + "# LLM_MODEL names a Unity AI Gateway-backed endpoint. We only use the serving\n", + "# endpoint permissions API to grant CAN_QUERY on that Gateway endpoint; no LLM\n", + "# request is routed through legacy model-serving invocation routes.\n", + "llm_endpoint = None\n", + "try:\n", + " llm_endpoint = w.serving_endpoints.get(LLM_MODEL)\n", + "except Exception:\n", + " matches = [ep for ep in w.serving_endpoints.list() if ep.name == LLM_MODEL]\n", + " if matches:\n", + " llm_endpoint = matches[0]\n", + "if llm_endpoint is None or not getattr(llm_endpoint, \"id\", None):\n", + " raise RuntimeError(f\"Could not resolve Gateway endpoint {LLM_MODEL} for permission grant\")\n", + "\n", + "w.serving_endpoints.update_permissions(\n", + " serving_endpoint_id=llm_endpoint.id,\n", + " access_control_list=[\n", + " ServingEndpointAccessControlRequest(\n", + " service_principal_name=app_sp_id,\n", + " permission_level=ServingEndpointPermissionLevel.CAN_QUERY,\n", + " )\n", + " ],\n", + ")\n", + "print(f\"Granted CAN_QUERY on Gateway endpoint {LLM_MODEL} to app SP {app_sp_id}\")\n", "\n", - "if _reuse_endpoint:\n", - " deployment_info = None\n", - " print(f\"\\u2705 Endpoint {endpoint_name} is READY (reused from previous deploy)\")\n", - "else:\n", - " deployment_info = agents.deploy(\n", - " model_name=UC_MODEL_NAME,\n", - " model_version=uc_registered_model_info.version,\n", - " scale_to_zero=False,\n", - " endpoint_name=endpoint_name,\n", - " environment_vars={\"MLFLOW_EXPERIMENT_ID\": str(prod_experiment_id)},\n", + "try:\n", + " w.api_client.do(\n", + " \"PATCH\",\n", + " f\"/api/2.0/permissions/apps/{APP_NAME}\",\n", + " body={\"access_control_list\": [{\"group_name\": \"account users\", 
\"permission_level\": \"CAN_USE\"}]},\n", " )\n", + " print(\"Granted CAN_USE on app to account users for notebook/job smoke tests\")\n", + "except Exception as e:\n", + " print(f\"Could not grant account users CAN_USE on app {APP_NAME}: {e}\")\n", + "\n", + "add(CATALOG, \"apps\", {\n", + " \"name\": APP_NAME,\n", + " \"url\": getattr(app_status, \"url\", \"\"),\n", + " \"service_principal_client_id\": app_sp_id,\n", + " \"oauth2_app_client_id\": getattr(app_status, \"oauth2_app_client_id\", \"\"),\n", + " \"agent\": 'refund',\n", + "})\n", + "print(\"Registered app in UC state\")\n", + "\n", + "deployment = w.apps.deploy(\n", + " app_name=app_status.name,\n", + " app_deployment=AppDeployment(source_code_path=source_code_path),\n", + ")\n", "\n", - " # Block the stage until the endpoint finishes deploying so the task\n", - " # fails fast (and visibly) when the deploy fails, instead of returning\n", - " # SUCCESS while the container is still building / failing in the\n", - " # background. 
Matches the wait pattern used by complaint_agent.\n", - " workspace = WorkspaceClient()\n", - " ready_endpoint = workspace.serving_endpoints.wait_get_serving_endpoint_not_updating(\n", - " name=endpoint_name,\n", - " timeout=timedelta(minutes=30),\n", - " )\n", "\n", - " if ready_endpoint.state.ready != EndpointStateReady.READY:\n", - " raise RuntimeError(\n", - " f\"Endpoint {endpoint_name} is {ready_endpoint.state.ready} after deployment; retry or investigate.\"\n", - " )\n", + "def _deploy_state(d):\n", + " st = getattr(d, \"status\", None)\n", + " s = getattr(st, \"state\", None) if st is not None else None\n", + " return getattr(s, \"value\", str(s)) if s is not None else \"\"\n", "\n", - " print(f\"\\u2705 Endpoint {endpoint_name} is READY\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# === Grant UC perms to the endpoint's runtime System Service Principal ===\n", - "#\n", - "# Model serving endpoints (including ones created by agents.deploy()) run\n", - "# their inference container as a workspace-level SCIM SP whose displayName\n", - "# is \"System Service Principal\". These SPs are NOT members of\n", - "# `account users`, so the catalog/schema/function grants made to\n", - "# `account users` elsewhere in the bundle do NOT apply to them.\n", - "#\n", - "# Result without this cell: every fresh agent endpoint fails on its first\n", - "# tool call with PERMISSION_DENIED (\"USE CATALOG\" / \"USE SCHEMA\" / \"EXECUTE\")\n", - "# until somebody manually grants permissions. The cell below discovers\n", - "# the System SPs via SCIM and grants them the perms they need. 
Idempotent\n", - "# (re-granting an existing privilege in UC is a no-op).\n", - "#\n", - "# Skipped when we reused an existing endpoint (no fresh SP to grant to).\n", - "if deployment_info is not None:\n", - " import sys\n", - " sys.path.append('../utils')\n", - " from agent_runtime_grants import grant_agent_runtime_perms\n", "\n", - " # Pass endpoint_name so the helper also grants to the endpoint\n", - " # creator — that's the actual runtime identity in EMBEDDED_CREDENTIALS\n", - " # mode workspaces (where 'System Service Principal' isn't created and\n", - " # `account users` may be empty at the workspace level).\n", - " grant_agent_runtime_perms(\n", - " spark,\n", - " CATALOG,\n", - " workspace_client=WorkspaceClient(),\n", - " endpoint_name=endpoint_name,\n", - " )\n", - "else:\n", - " print(\"♻ Endpoint reused — skipping runtime SP grants (already applied on first deploy).\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "212d8c67-af4d-48c9-88d8-aba01a22edb0", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], - "source": [ - "print(deployment_info)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, - "inputWidgets": {}, - "nuid": "a35fccd0-52e3-4efd-b345-83756040e098", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "source": [ - "##### record model in state" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": { - "byteLimit": 2048000, - "rowLimit": 10000 - }, - "inputWidgets": {}, - "nuid": "ff483b29-ed19-4f2b-ba21-f398c0968ff2", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "outputs": [], - "source": [ - "# Also add to 
UC-state — but only when we actually deployed a new endpoint.\n", - "# On the reuse path the endpoint was already registered by a previous run,\n", - "# so re-adding here would just create a duplicate uc_state row.\n", - "if deployment_info is not None:\n", - " import sys\n", - " sys.path.append('../utils')\n", - " from uc_state import add\n", - "\n", - " add(dbutils.widgets.get(\"CATALOG\"), \"endpoints\", deployment_info)\n", - "else:\n", - " print(\"\\u267b\\ufe0f Endpoint already tracked in uc_state from a previous deploy; skipping add.\")" + "deadline = time.time() + 30 * 60\n", + "while True:\n", + " current_dep = w.apps.get_deployment(app_name=app_status.name, deployment_id=deployment.deployment_id)\n", + " state = _deploy_state(current_dep)\n", + " print(f\"Deployment state: {state}\")\n", + " if state == \"SUCCEEDED\":\n", + " deployment_status = current_dep\n", + " break\n", + " if state in (\"FAILED\", \"STOPPED\"):\n", + " raise RuntimeError(f\"Deployment failed for {app_status.name}: state={state}\")\n", + " if time.time() > deadline:\n", + " raise TimeoutError(f\"Deployment for {app_status.name} not ready after 30 minutes (last state: {state})\")\n", + " time.sleep(10)\n", + "\n", + "print(f\"Refund agent app deployed: {getattr(app_status, 'url', '')}\")\n", + "display(deployment_status)\n" ] }, { @@ -1218,10 +711,10 @@ "\n", "Scorer set (4, all at 100% sampling):\n", "\n", - "- `safety` — built-in `Safety()` LLM judge for harmful or inappropriate content\n", - "- `relevance_to_query` — built-in `RelevanceToQuery()` LLM judge — does the answer address the question\n", - "- `operational_quality` — generic `Guidelines` — concrete data, not a hedge\n", - "- `refund_policy_compliance` — refund-specific `Guidelines` — recommendation matches the policy" + "- `safety` \u2014 built-in `Safety()` LLM judge for harmful or inappropriate content\n", + "- `relevance_to_query` \u2014 built-in `RelevanceToQuery()` LLM judge \u2014 does the answer address the 
question\n", + "- `operational_quality` \u2014 generic `Guidelines` \u2014 concrete data, not a hedge\n", + "- `refund_policy_compliance` \u2014 refund-specific `Guidelines` \u2014 recommendation matches the policy" ] }, { @@ -1251,11 +744,11 @@ " sampling = ScorerSamplingConfig(sample_rate=sample_rate)\n", " if name in existing:\n", " existing[name].start(sampling_config=sampling)\n", - " print(f\" ↺ {name} — restarted at {sample_rate:.0%} sample rate\")\n", + " print(f\" \u21ba {name} \u2014 restarted at {sample_rate:.0%} sample rate\")\n", " return existing[name]\n", " registered = scorer_obj.register(name=name, experiment_id=prod_experiment_id)\n", " registered.start(sampling_config=sampling)\n", - " print(f\" ✅ {name} — registered + started at {sample_rate:.0%} sample rate\")\n", + " print(f\" \u2705 {name} \u2014 registered + started at {sample_rate:.0%} sample rate\")\n", " return registered\n", "\n", "\n", @@ -1277,12 +770,12 @@ " sample_rate=1.0,\n", ")\n", "\n", - "# Domain — refund-specific policy compliance.\n", + "# Domain \u2014 refund-specific policy compliance.\n", "_register_scorer(\n", " Guidelines(\n", " name=\"refund_policy_compliance\",\n", " guidelines=[\n", - " \"Recommendations must be one of: no refund, partial refund, or full refund — never anything outside this set.\",\n", + " \"Recommendations must be one of: no refund, partial refund, or full refund \u2014 never anything outside this set.\",\n", " \"If a refund is recommended, the response must cite the order_id it applies to.\",\n", " \"If the order is older than the refund window, the response must explicitly mention that the refund window has expired.\",\n", " \"The response must reference the specific reason from the order data (e.g. 
late delivery, missing item, food quality).\",\n", @@ -1292,122 +785,7 @@ " sample_rate=1.0,\n", ")\n", "\n", - "print(\"✅ Production monitoring enabled — 4 scorers active at 100% sampling\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Prompt Registry\n", - "\n", - "Register the system prompt that was just baked into `agent.py` under\n", - "`{CATALOG}.prompts.refund_system` so it lives alongside the deployed model\n", - "in Unity Catalog. We extract the template directly from `agent.py` rather\n", - "than re-declaring it here, so the registry tracks **exactly** what was\n", - "deployed (no drift possible).\n", - "\n", - "Re-running the stage creates a new prompt version and bumps the\n", - "`production` alias. The schema and all prompts are dropped automatically\n", - "by `DROP CATALOG ... CASCADE` in `destroy.ipynb`, so no extra cleanup\n", - "wiring is needed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "import sys\n", - "\n", - "sys.path.append('../utils')\n", - "from prompt_registry import seed_prompt_history\n", - "\n", - "# Use the absolute path captured by %%writefilev (see cell 13) so this works\n", - "# even when CWD has drifted between cells.\n", - "_agent_py_path = _WRITEFILEV_ABS_PATHS.get(\"agent.py\", \"agent.py\")\n", - "with open(_agent_py_path) as f:\n", - " _agent_py = f.read()\n", - "\n", - "# Extract from _FALLBACK_PROMPT — the source-of-truth literal in agent.py.\n", - "# The deployed agent prefers mlflow.genai.load_prompt() at startup but\n", - "# falls back to this literal if the registry is empty or unreachable.\n", - "_match = re.search(r'_FALLBACK_PROMPT\\s*=\\s*\"\"\"(.*?)\"\"\"', _agent_py, re.DOTALL)\n", - "if not _match:\n", - " raise RuntimeError(\n", - " \"Could not extract _FALLBACK_PROMPT from agent.py. 
\"\n", - " \"If the prompt block was renamed, update this regex.\"\n", - " )\n", - "\n", - "_uc_version = (\n", - " uc_registered_model_info.version\n", - " if uc_registered_model_info is not None\n", - " else \"reused-endpoint\"\n", - ")\n", - "\n", - "# Two earlier versions of the refund prompt, seeded on first deploy so the\n", - "# Prompt Registry UI shows v1 → v2 → v3 history. These are demo seeds, NOT a\n", - "# real engineering changelog — seed_prompt_history tags each with\n", - "# is_demo_seed=\"true\" so anyone auditing the registry can tell.\n", - "_REFUND_V1 = (\n", - " \"You are a refund agent for a food delivery service. \"\n", - " \"Given an order_id, decide whether to issue a refund and how much. \"\n", - " \"Return a single-line JSON with `refund_usd` (float), `refund_class` \"\n", - " \"(\\\"none\\\" | \\\"partial\\\" | \\\"full\\\"), and `reason` (short explanation).\"\n", - ")\n", - "_REFUND_V2 = \"\"\"You are RefundGPT, a CX agent responsible for refund decisions on food delivery orders.\n", - "\n", - " You can call tools to gather the information you need. Start with an `order_id`.\n", - "\n", - " Instructions:\n", - " 1. Call `order_details(order_id)` first to get event history and confirm the id is valid and the order was delivered.\n", - " 2. Figure out the delivery duration by calling `get_order_delivery_time(order_id)`.\n", - " 3. Extract the location (either directly or from the first event's body).\n", - " 4. Call `get_location_timings(location)` to get the P50/P75/P99 values.\n", - " 5. 
Compare actual delivery time to those percentiles.\n", - "\n", - " Refund policy (SLA-based):\n", - " - If the order arrived AFTER the P75 delivery time: recommend a `partial` or `full` refund based on how late.\n", - " - If the order arrived BEFORE the P75: no refund.\n", - "\n", - " Output a single-line JSON with these fields:\n", - " - `refund_usd` (float),\n", - " - `refund_class` (\\\"none\\\" | \\\"partial\\\" | \\\"full\\\"),\n", - " - `reason` (short human explanation).\n", - "\n", - " You must return only the JSON. No extra text or markdown.\"\"\"\n", - "\n", - "_common_tags = {\n", - " \"agent\": \"refund\",\n", - " \"stage\": \"refunder_agent\",\n", - " \"uc_model\": UC_MODEL_NAME,\n", - " \"consumed_via\": \"mlflow.genai.load_prompt at endpoint startup\",\n", - "}\n", - "\n", - "seed_prompt_history(\n", - " spark=spark,\n", - " catalog=CATALOG,\n", - " name=\"refund_system\",\n", - " historical=[\n", - " {\n", - " \"template\": _REFUND_V1,\n", - " \"commit_message\": \"v1: bare-bones refund decisioner, no SLA logic or tool use (demo history seed)\",\n", - " \"tags\": _common_tags,\n", - " },\n", - " {\n", - " \"template\": _REFUND_V2,\n", - " \"commit_message\": \"v2: added tool-calling + SLA-based refund policy (P75 cutoff) (demo history seed)\",\n", - " \"tags\": _common_tags,\n", - " },\n", - " ],\n", - " current={\n", - " \"template\": _match.group(1).strip(),\n", - " \"commit_message\": f\"v3 (production): SLA + goodwill credit path — UC model version {_uc_version}\",\n", - " \"tags\": {**_common_tags, \"uc_model_version\": str(_uc_version)},\n", - " },\n", - ")" + "print(\"\u2705 Production monitoring enabled \u2014 4 scorers active at 100% sampling\")" ] } ], @@ -1483,32 +861,6 @@ }, "widgetType": "text" } - }, - "REFUND_AGENT_ENDPOINT_NAME": { - "currentValue": "", - "nuid": "ccf262a5-2c34-435f-a8f3-ea86126c6353", - "typedWidgetInfo": { - "autoCreated": false, - "defaultValue": "", - "label": "", - "name": "REFUND_AGENT_ENDPOINT_NAME", - 
"options": { - "validationRegex": null, - "widgetDisplayType": "Text" - }, - "parameterDataType": "String" - }, - "widgetInfo": { - "defaultValue": "", - "label": "", - "name": "REFUND_AGENT_ENDPOINT_NAME", - "options": { - "autoCreated": false, - "validationRegex": null, - "widgetType": "text" - }, - "widgetType": "text" - } } } }, @@ -1518,4 +870,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/stages/refunder_stream.ipynb b/stages/refunder_stream.ipynb index 0835d3f..f8d17fd 100644 --- a/stages/refunder_stream.ipynb +++ b/stages/refunder_stream.ipynb @@ -72,10 +72,15 @@ "from databricks.sdk import WorkspaceClient\n", "import databricks.sdk.service.jobs as j\n", "import os\n", + "import sys\n", + "\n", + "sys.path.append('../utils')\n", + "from agent_app_client import refund_agent_app_name\n", "\n", "w = WorkspaceClient()\n", "\n", "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", + "REFUND_AGENT_APP_NAME = refund_agent_app_name(CATALOG)\n", "\n", "notebook_abs_path = os.path.abspath(\"../jobs/refund_recommender_stream\")\n", "notebook_dbx_path = notebook_abs_path.replace(\n", @@ -85,10 +90,8 @@ "\n", "job_name = f\"Refund Recommender Stream ({CATALOG})\"\n", "\n", - "# timeout_seconds=600 (10 min, matches cron) so a single hung UDF call against\n", - "# a stalled refund-agent endpoint cannot block the whole queue forever.\n", - "# Without this, an orphaned run from an earlier test session can hold the queue\n", - "# slot indefinitely while every subsequent cron tick piles up as QUEUED.\n", + "# timeout_seconds=600 (10 min, matches cron) so a single hung refund-agent app\n", + "# call cannot block the whole queue forever.\n", "task_def = [\n", " j.Task(\n", " task_key=\"refund_recommender_stream\",\n", @@ -97,7 +100,7 @@ " notebook_path=notebook_dbx_path,\n", " base_parameters={\n", " \"CATALOG\": CATALOG,\n", - " \"REFUND_AGENT_ENDPOINT_NAME\": dbutils.widgets.get(\"REFUND_AGENT_ENDPOINT_NAME\"),\n", + " 
\"REFUND_AGENT_APP_NAME\": REFUND_AGENT_APP_NAME,\n", " },\n", " )\n", " )\n", @@ -109,9 +112,9 @@ ")\n", "\n", "# queue.enabled=False: drop cron triggers if a previous run is still active\n", - "# instead of stacking them up. For an availableNow catch-up stream, dropping\n", + "# instead of stacking them up. For an availableNow catch-up stream, dropping\n", "# is correct: the NEXT tick will pick up whatever rows the previous run didn't\n", - "# get to. Default (enabled=True) creates an unbounded backlog on any hang.\n", + "# get to.\n", "queue_def = j.QueueSettings(enabled=False)\n", "\n", "existing = [jb for jb in w.jobs.list(name=job_name) if jb.settings.name == job_name]\n", @@ -120,7 +123,7 @@ " w.jobs.reset(job_id=job_id, new_settings=j.JobSettings(\n", " name=job_name, tasks=task_def, schedule=schedule_def, queue=queue_def,\n", " ))\n", - " print(f\"♻️ Updated existing job_id={job_id}\")\n", + " print(f\"Updated existing job_id={job_id}\")\n", "else:\n", " job = w.jobs.create(name=job_name, tasks=task_def, schedule=schedule_def, queue=queue_def)\n", " job_id = job.job_id\n", @@ -128,11 +131,10 @@ " sys.path.append('../utils')\n", " from uc_state import add\n", " add(CATALOG, \"jobs\", job)\n", - " print(f\"✅ Created job_id={job_id}\")\n", + " print(f\"Created job_id={job_id}\")\n", "\n", "w.jobs.run_now(job_id=job_id)\n", - "print(f\"🚀 Started run of {job_name}\")\n", - "" + "print(f\"Started run of {job_name} against app {REFUND_AGENT_APP_NAME}\")\n" ], "execution_count": null, "outputs": [] @@ -178,32 +180,6 @@ }, "widgetType": "text" } - }, - "REFUND_AGENT_ENDPOINT_NAME": { - "currentValue": "", - "nuid": "f8358127-303e-4f0e-ac3c-e686c70adce6", - "typedWidgetInfo": { - "autoCreated": false, - "defaultValue": "", - "label": "", - "name": "REFUND_AGENT_ENDPOINT_NAME", - "options": { - "validationRegex": null, - "widgetDisplayType": "Text" - }, - "parameterDataType": "String" - }, - "widgetInfo": { - "defaultValue": "", - "label": "", - "name": 
"REFUND_AGENT_ENDPOINT_NAME", - "options": { - "autoCreated": false, - "validationRegex": null, - "widgetType": "text" - }, - "widgetType": "text" - } } } }, diff --git a/utils/agent_app_client.py b/utils/agent_app_client.py new file mode 100644 index 0000000..3af7aaa --- /dev/null +++ b/utils/agent_app_client.py @@ -0,0 +1,224 @@ +"""Helpers for DAIS custom agents deployed as Databricks Apps.""" + +from __future__ import annotations + +import hashlib +import json +import os +import re +from typing import Any, Iterable + +import requests +from databricks.sdk import WorkspaceClient + + +_APP_NAME_MAX_LEN = 30 +_APP_NAME_SAFE = re.compile(r"[^a-z0-9-]+") + + +def _safe_app_name(value: str) -> str: + normalized = _APP_NAME_SAFE.sub("-", value.lower()).strip("-") + normalized = re.sub(r"-+", "-", normalized) + if not normalized: + raise ValueError("App name cannot be empty") + if len(normalized) <= _APP_NAME_MAX_LEN: + return normalized + + digest = hashlib.sha1(normalized.encode("utf-8")).hexdigest()[:6] + prefix_len = _APP_NAME_MAX_LEN - len(digest) - 1 + prefix = normalized[:prefix_len].rstrip("-") + return f"{prefix}-{digest}" + + +def refund_agent_app_name(catalog: str) -> str: + return _safe_app_name(f"refund-agent-{catalog}") + + +def complaint_agent_app_name(catalog: str) -> str: + return _safe_app_name(f"complaint-agent-{catalog}") + + +def get_notebook_token(dbutils: Any) -> str: + ctx = dbutils.notebook.entry_point.getDbutils().notebook().getContext() + token = ctx.apiToken().get() + if not token: + raise RuntimeError("Notebook API token is unavailable") + return token + + +def exchange_notebook_token_for_app_token( + *, + host: str, + notebook_token: str, + app_oauth_client_id: str, + timeout: float = 30, +) -> str: + response = requests.post( + url=f"{host.rstrip('/')}/oidc/v1/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "subject_token": notebook_token, + "subject_token_type": 
"urn:databricks:params:oauth:token-type:personal-access-token", + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "scope": "all-apis", + "audience": app_oauth_client_id, + }, + timeout=timeout, + ) + response.raise_for_status() + token = response.json().get("access_token") + if not token: + raise RuntimeError("Databricks token exchange returned no access_token") + return token + + +def app_bearer_token( + *, + app_name: str, + w: WorkspaceClient | None = None, + dbutils: Any | None = None, + timeout: float = 30, +) -> str: + w = w or WorkspaceClient() + if dbutils is None: + header = w.config.authenticate().get("Authorization", "") + if not header.startswith("Bearer "): + raise RuntimeError("Databricks OAuth bearer token is unavailable") + return header.removeprefix("Bearer ") + + app = w.apps.get(app_name) + client_id = getattr(app, "oauth2_app_client_id", None) + if not client_id: + raise RuntimeError(f"App {app_name!r} has no oauth2_app_client_id") + return exchange_notebook_token_for_app_token( + host=w.config.host, + notebook_token=get_notebook_token(dbutils), + app_oauth_client_id=client_id, + timeout=timeout, + ) + + +def app_url(app_name: str, w: WorkspaceClient | None = None) -> str: + app = (w or WorkspaceClient()).apps.get(app_name) + url = getattr(app, "url", None) + if not url: + raise RuntimeError(f"App {app_name!r} has no URL") + return url.rstrip("/") + + +def app_request_context( + *, + app_name: str, + w: WorkspaceClient | None = None, + dbutils: Any | None = None, + timeout: float = 30, +) -> dict[str, str]: + w = w or WorkspaceClient() + return { + "app_name": app_name, + "url": app_url(app_name, w=w), + "bearer_token": app_bearer_token( + app_name=app_name, + w=w, + dbutils=dbutils, + timeout=timeout, + ), + } + + +def call_agent_app( + *, + app_name: str, + input_messages: list[dict[str, Any]], + w: WorkspaceClient | None = None, + dbutils: Any | None = None, + timeout: float = 120, + extra_body: dict[str, Any] | 
None = None, +) -> dict[str, Any]: + w = w or WorkspaceClient() + body: dict[str, Any] = {"input": input_messages} + if extra_body: + body.update(extra_body) + + token = app_bearer_token(app_name=app_name, w=w, dbutils=dbutils, timeout=min(timeout, 30)) + response = requests.post( + url=f"{app_url(app_name, w=w)}/responses", + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, + json=body, + timeout=timeout, + ) + response.raise_for_status() + return response.json() + + +def extract_response_text(response: Any) -> str: + if isinstance(response, str): + try: + response = json.loads(response) + except json.JSONDecodeError: + return response + + if hasattr(response, "model_dump"): + response = response.model_dump(mode="json") + elif hasattr(response, "dict"): + response = response.dict() + + if not isinstance(response, dict): + raise TypeError(f"Unsupported response type: {type(response).__name__}") + + direct = response.get("output_text") + if isinstance(direct, str) and direct: + return direct + + for item in _iter_response_items(response.get("output", [])): + content = item.get("content") + if isinstance(content, str) and content: + return content + for content_item in _iter_response_items(content or []): + text = content_item.get("text") + if isinstance(text, str) and text: + return text + + raise ValueError(f"Could not extract output text from response keys: {sorted(response.keys())}") + + +def call_agent_app_text(**kwargs: Any) -> str: + return extract_response_text(call_agent_app(**kwargs)) + + +def gateway_chat_probe( + *, + llm_model: str, + w: WorkspaceClient | None = None, + dbutils: Any | None = None, + timeout: float = 30, +) -> None: + w = w or WorkspaceClient() + if dbutils is not None: + token = get_notebook_token(dbutils) + else: + header = w.config.authenticate().get("Authorization", "") + if not header.startswith("Bearer "): + raise RuntimeError("Databricks OAuth bearer token is unavailable") + token = 
header.removeprefix("Bearer ") + + response = requests.post( + url=f"{w.config.host.rstrip('/')}/ai-gateway/mlflow/v1/chat/completions", + headers={"Authorization": f"Bearer {token}", "Content-Type": "application/json"}, + json={ + "model": llm_model, + "messages": [{"role": "user", "content": "Say gateway ok."}], + "max_tokens": 8, + }, + timeout=timeout, + ) + response.raise_for_status() + + +def _iter_response_items(value: Any) -> Iterable[dict[str, Any]]: + if isinstance(value, dict): + yield value + elif isinstance(value, list): + for item in value: + if isinstance(item, dict): + yield item From 00c3d7795a230cbb25ddd6c6b4420773f90cd06f Mon Sep 17 00:00:00 2001 From: djliden <7102904+djliden@users.noreply.github.com> Date: Wed, 10 Jun 2026 07:59:37 -0500 Subject: [PATCH 2/3] feat(dais2026): route app-based agents through Unity AI Gateway MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reconciles the app-based agent migration on this branch with the gateway-routing intent from demo/dais-2026 (fe14776): the Refund and Complaint agents stay deployed as Databricks Apps AND route every LLM call through Unity AI Gateway. Verified end-to-end on the sandbox `all` target (25/25 tasks; streams produced 75k+ rows; live + 8-way concurrent calls all 200). Gateway routing (always-on, no model-serving fallback) - New AI_GATEWAY_ENDPOINT_NAME job param (all target), distinct from LLM_MODEL: LLM_MODEL stays the FM name for generators/support; the gateway endpoint name is sent verbatim as the request `model` to /ai-gateway/mlflow/v1. Default databricks-claude-sonnet-4-5 so existing deploys keep working; override per governed endpoint. - Both apps/*/agent.py read GATEWAY_ENDPOINT_NAME from the env (with a transitional LLM_MODEL fallback); static app.yaml + the deploy stages inject it. 
Correctness - Complaint agent: configure DSPy once at import, apply a fresh-token LM per request via dspy.context() instead of dspy.configure() — avoids the AgentServer worker-thread thread-affinity error under concurrency. - Refund agent's ChatDatabricks(use_ai_gateway=True) already refreshes the bearer per request via DatabricksOpenAI's BearerAuth (no rebuild needed). App-name + warehouse drift fixes (--var catalog vs --params CATALOG) - New OPS_WAREHOUSE_NAME, REFUND_AGENT_APP_NAME, COMPLAINT_AGENT_APP_NAME params baked from ${var.catalog} at deploy time. - utils/agent_app_client.resolve_agent_app_name() prefers the baked param, re-sanitises, falls back to deriving; threaded through the stream jobs, eval stages, agent stages, and ops dashboard. Runbook + docs - SETUP.ipynb step 5 (create gateway endpoint; CAN_QUERY to each agent App SP, not account users) and MLflow.ipynb gateway demo beats, rewritten for the App architecture. - README + AGENTS: --var vs --params drift table and the gateway/LLM_MODEL distinction. Also folds in the app-server simplification already in the working tree (start_server.py uses AgentServer's native /responses; ops dashboard output extraction tightened) — both exercised by the end-to-end run. 
Co-authored-by: Isaac --- .gitignore | 1 + AGENTS.md | 49 +++++ README.md | 27 ++- apps/caspers-ops-dashboard/app/main.py | 31 +-- apps/complaint-agent/agent.py | 253 +++++++++++++++---------- apps/complaint-agent/app.yaml | 5 +- apps/complaint-agent/start_server.py | 38 +--- apps/refund-agent/agent.py | 48 ++++- apps/refund-agent/app.yaml | 5 +- apps/refund-agent/start_server.py | 45 +---- databricks.yml | 42 ++++ demos/dais2026-runbooks/MLflow.ipynb | 90 ++++++++- demos/dais2026-runbooks/SETUP.ipynb | 152 +++++++++++++++ jobs/complaint_agent_stream.ipynb | 42 ++-- jobs/refund_recommender_stream.ipynb | 35 +--- stages/complaint_agent.ipynb | 80 ++++---- stages/complaint_agent_stream.ipynb | 8 +- stages/complaint_evaluation.ipynb | 8 +- stages/operational_app.ipynb | 25 ++- stages/refund_evaluation.ipynb | 8 +- stages/refunder_agent.ipynb | 67 ++++--- stages/refunder_stream.ipynb | 8 +- utils/agent_app_client.py | 50 +++-- 23 files changed, 739 insertions(+), 378 deletions(-) diff --git a/.gitignore b/.gitignore index e0e301e..7d16da2 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,7 @@ __pycache__ .databricks .claude .cursor +.codex .bundle /.vscode/ node_modules/ diff --git a/AGENTS.md b/AGENTS.md index cfc1c2d..6e41c1e 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -633,6 +633,55 @@ Use full redeploy instead if: - Ensure it's plumbed through all layers - Check `databricks.yml` parameters, stage parameter parsing, and implementation usage +#### 5a. `--var catalog` vs `--params CATALOG` drift +- **Why fragile**: Two separate dials use a catalog name and they can disagree. + - `bundle deploy --var catalog=X` → resolves `${var.catalog}` at deploy + time. Bakes `X` into every DABs-managed resource name (e.g. the `all` + target's `caspers_ops_warehouse` becomes `X-ops-warehouse`), every + AI/BI dashboard name, every dashboard `dataset_catalog`, and every job + parameter `default: ${var.catalog}...` (incl. 
`CATALOG`, + `REFUND_AGENT_APP_NAME`, `COMPLAINT_AGENT_APP_NAME`, + `OPS_WAREHOUSE_NAME`, `SUPERVISOR_ENDPOINT_NAME`). + - `bundle run caspers --params "CATALOG=Y"` → overrides ONLY the + run-time `CATALOG` widget value inside stage notebooks. Cannot rename + anything DABs already created. +- **Symptom when they disagree**: stages that reconstruct a DABs-managed + resource name from the run-time `CATALOG` widget fail to find it. + Example: deploying with the default and then running + `--params CATALOG=mycatalog` against the `all` target makes + `stages/operational_app.ipynb` fail with + `RuntimeError: Warehouse 'mycatalog-ops-warehouse' not found` because + DABs created `caspersdev-ops-warehouse`. +- **Best practice**: + - When in doubt, pass the same catalog to both: + `bundle deploy -t TARGET --var catalog=X` then + `bundle run caspers --params "CATALOG=X"`. + - When adding a stage that needs the name of a DABs-managed resource (or + of an agent App the agent stages deployed), do NOT reconstruct it from + the `CATALOG` widget. Add a dedicated job parameter with a + `${var.catalog}-...` default in `databricks.yml` (this is how + `OPS_WAREHOUSE_NAME`, `REFUND_AGENT_APP_NAME` and + `COMPLAINT_AGENT_APP_NAME` are wired), then read that parameter via + `dbutils.widgets.get(...)` in the stage. The deploy-time value rides + through the job parameter into the run-time widget, so the two dials + physically cannot disagree. + - The agent App names additionally pass through + `utils/agent_app_client.resolve_agent_app_name(...)`, which prefers the + baked param and re-sanitises to the Databricks Apps name rules; the + `all`-target param block in `databricks.yml` (with its + deploy-time-vs-run-time comment) is the canonical example of this pattern. + +#### 5b. Unity AI Gateway endpoint name (`AI_GATEWAY_ENDPOINT_NAME`) +- **Distinct from `LLM_MODEL`**: `LLM_MODEL` is the foundation model the + generators / support agent call directly.
`AI_GATEWAY_ENDPOINT_NAME` is the + governed Unity AI Gateway endpoint the Refund + Complaint **App** agents send + every LLM call through (gateway-always-on, no model-serving fallback). It is + sent verbatim as the request `model` to `/ai-gateway/mlflow/v1`. +- **Manual setup**: the v2 Beta gateway is UI-created only and has no + permissions API, so `CAN_QUERY` must be granted to **each agent App's service + principal** by hand (the App SP is not in `account users`). See step 5 of + `demos/dais2026-runbooks/SETUP.ipynb`. + ### 6. Resource Dependencies - **Why fragile**: Stages create resources that others depend on (endpoints, tables, etc.) - **When touching**: Creation/deletion order, stage dependencies diff --git a/README.md b/README.md index d283269..583ae40 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,37 @@ Available targets: | `free` | Data generation, Lakeflow pipeline (Free Edition compatible) | | `all` | Everything end-to-end: refund + complaints + Operational Dashboard (3 Genies + 6 Knowledge Assistants + Multi-Agent Supervisor + Lakebase-backed FastAPI app) | -Optionally specify a catalog (default: `caspersdev`): +Optionally specify a catalog (default: `caspersdev`). There are **two** dials +that take a catalog name and they must agree: + +| Dial | When | What it controls | +|---|---|---| +| `bundle deploy --var catalog=X` | deploy time | the catalog baked into every DABs-managed resource — the `all` target's `caspers_ops_warehouse` SQL warehouse, AI/BI dashboard names, dashboard `dataset_catalog`, and the *default* value of every job parameter that uses `${var.catalog}` (including `CATALOG`, `REFUND_AGENT_APP_NAME`, `COMPLAINT_AGENT_APP_NAME`, `OPS_WAREHOUSE_NAME`, etc.) | +| `bundle run caspers --params "CATALOG=X"` | run time | only the value of the `CATALOG` widget inside stage notebooks. Cannot rename anything DABs already created. | + +If they disagree (e.g.
`bundle deploy -t all` with the default + `bundle run +--params CATALOG=mycatalog`), the `all` target will fail at the +`Operational_App` stage because the warehouse DABs created (`caspersdev-ops-warehouse`) +is not what the stage looks up (`mycatalog-ops-warehouse`). The fix is to +pass the same catalog to both: ```bash +databricks bundle deploy -t all --var catalog=mycatalog databricks bundle run caspers --params "CATALOG=mycatalog" ``` +For targets other than `all` (no DABs-owned warehouse/dashboards), +`--params CATALOG=mycatalog` alone usually works, but passing both keeps the +deploy-time and run-time catalogs in sync and is the safer habit. + +> **Agents on the `all` target run as Databricks Apps through Unity AI +> Gateway.** The Refund and Complaint agents are deployed as Apps +> (`apps/refund-agent`, `apps/complaint-agent`) and route every LLM call +> through a UI-created Unity AI Gateway endpoint (`AI_GATEWAY_ENDPOINT_NAME`). +> The gateway and its `CAN_QUERY` grant to each agent App's service principal +> are manual one-time setup — see step 5 of +> `demos/dais2026-runbooks/SETUP.ipynb`. 
+ ## Clean Up ```bash diff --git a/apps/caspers-ops-dashboard/app/main.py b/apps/caspers-ops-dashboard/app/main.py index d954144..301a69c 100644 --- a/apps/caspers-ops-dashboard/app/main.py +++ b/apps/caspers-ops-dashboard/app/main.py @@ -1076,32 +1076,11 @@ def _call_agent_app(app_name: str, configured_url: str, payload: dict) -> dict: def _extract_agent_output_text(data: dict) -> str: - output_text = data.get("output_text") - if isinstance(output_text, str) and output_text: - return output_text - - output = data.get("output") or [] - if isinstance(output, dict): - output = [output] - for out in output: - if not isinstance(out, dict): - continue - content = out.get("content") - if isinstance(content, str) and content: - return content - if isinstance(content, dict): - content = [content] - if isinstance(content, list): - for part in content: - if isinstance(part, dict): - text = part.get("text") - if isinstance(text, str) and text: - return text - - choices = data.get("choices") or [] - if choices: - return (choices[0].get("message") or {}).get("content", "") - return "" + try: + text = data["output"][0]["content"][0]["text"] + except (KeyError, IndexError, TypeError): + return "" + return text if isinstance(text, str) else "" def _build_refund_user_message(req: "RefundRequest") -> str: diff --git a/apps/complaint-agent/agent.py b/apps/complaint-agent/agent.py index 94c4671..4ef5c4d 100644 --- a/apps/complaint-agent/agent.py +++ b/apps/complaint-agent/agent.py @@ -1,4 +1,6 @@ +import json import os +import re import uuid import warnings from typing import Literal, Optional @@ -20,9 +22,38 @@ mlflow.dspy.autolog(log_traces=True) CATALOG = os.environ["DATABRICKS_CATALOG"] -LLM_MODEL = os.environ["LLM_MODEL"] -HOST = (os.environ.get("DATABRICKS_HOST") or Config().host).rstrip("/") -GATEWAY_BASE_URL = f"{HOST}/ai-gateway/mlflow/v1" +# Unity AI Gateway endpoint that ALL of this agent's LLM calls route through. +# Gateway-always-on: there is no model-serving fallback. 
Sent verbatim as the +# `model` field to /ai-gateway/mlflow/v1, so it must name a queryable +# gateway route (its CAN_QUERY is granted to the App SP manually — see runbook). +# Falls back to LLM_MODEL during the migration to the dedicated +# AI_GATEWAY_ENDPOINT_NAME param so the App works whether the deploy stage +# injects the old or new env var. +GATEWAY_ENDPOINT_NAME = os.environ.get("AI_GATEWAY_ENDPOINT_NAME") or os.environ["LLM_MODEL"] +COMPLAINT_TRIAGE_PROMPT = """Decision framework: +- Use exactly one complaint_category: delivery_delay, missing_items, food_quality, service_issue, billing, or other. +- Use decision "suggest_credit" only when a concrete credit amount is appropriate. Otherwise use "escalate". +- Delivery delays: if actual delivery is below P75, credit_amount should be 0.0 with low confidence; P75-P99 suggests about 15% of order total; above P99 suggests about 25%. +- Missing items: use item prices when the claimed item appears in the order; otherwise escalate. +- Food quality: minor issues can suggest about 20%; severe or health/safety issues should escalate urgently. +- For suggest_credit, credit_amount and confidence are required and priority must be null. +- For escalate, priority is required and credit_amount/confidence must be null. +- Rationale must cite specific evidence and stay under 150 words. 
+ +Return only this JSON shape: +{"order_id":"","complaint_category":"delivery_delay|missing_items|food_quality|service_issue|billing|other","decision":"suggest_credit|escalate","credit_amount":0.0,"confidence":"high|medium|low","priority":null,"rationale":"..."}""" + + +def _workspace_host() -> str: + host = (os.environ.get("DATABRICKS_HOST") or Config().host or "").rstrip("/") + if not host: + raise RuntimeError("Databricks workspace host is unavailable") + if not host.startswith(("http://", "https://")): + host = f"https://{host}" + return host + + +GATEWAY_BASE_URL = f"{_workspace_host()}/ai-gateway/mlflow/v1" def _auth_header() -> str: @@ -36,28 +67,45 @@ def _token() -> str: return _auth_header().removeprefix("Bearer ") -def _configure_dspy() -> None: - lm = dspy.LM( - f"openai/{LLM_MODEL}", +def _build_lm() -> dspy.LM: + """Build a DSPy LM bound to the AI Gateway with a *fresh* bearer token. + + Rebuilt on every request (see `_run_triage`) so a rotated OAuth M2M token + is picked up immediately — dspy.LM / litellm bake `api_key` in at + construction time and expose no callable-key hook. + """ + return dspy.LM( + f"openai/{GATEWAY_ENDPOINT_NAME}", api_base=GATEWAY_BASE_URL, api_key=_token(), - max_tokens=2000, - num_retries=20, + max_tokens=1000, + num_retries=3, cache=False, ) - dspy.configure(lm=lm) def _validate_gateway_endpoint() -> None: client = OpenAI(api_key=_token(), base_url=GATEWAY_BASE_URL, timeout=30) client.chat.completions.create( - model=LLM_MODEL, + model=GATEWAY_ENDPOINT_NAME, messages=[{"role": "user", "content": "Say gateway ok."}], max_tokens=8, ) _validate_gateway_endpoint() + +# Configure DSPy's global settings ONCE, on the import (main) thread. The +# MLflow AgentServer dispatches the request handler on FastAPI worker threads, +# and `dspy.configure()` enforces thread-affinity — only the thread that first +# configured it may reconfigure. 
Calling `dspy.configure()` from inside the +# request handler therefore raises "dspy.settings can only be changed by the +# thread that initially configured it" on the first request that lands on a +# different worker thread. We configure once here and apply a fresh-token LM +# per request via the thread-safe `dspy.context(...)` override in +# `_run_triage`. This base LM is NOT what serves requests. +dspy.configure(lm=_build_lm(), adapter=dspy.ChatAdapter(use_json_adapter_fallback=False)) + _uc_client = None @@ -128,57 +176,6 @@ def parse_priority(cls, v): return v -class ComplaintTriage(dspy.Signature): - """Analyze customer complaints for Casper's Kitchens and recommend triage actions. - - Process: - 1. Extract order_id from complaint - 2. Use get_order_overview(order_id) for order details and items - 3. Use get_order_timing(order_id) for delivery timing - 4. For delays, use get_location_timings(location) for percentile benchmarks - 5. Make data-backed decision - - Decision Framework: - - SUGGEST_CREDIT (with credit_amount and confidence): - - Delivery delays: Compare actual delivery time to location percentiles - * P99: Suggest 25% of order total (high confidence) - - Missing items: Use actual item prices from order data when available - * Verify claimed item exists in order (affects confidence) - * Use real costs from order data, or estimate $8-12 per item if unavailable - - Food quality: 20-40% of order total based on severity - * Minor issues (slightly cold, minor preparation issue): 20% (medium confidence) - * Major issues (completely inedible, wrong preparation, health concern): 40% (high confidence) - * Vague complaints ("bad", "gross"): escalate instead - - ESCALATE (with priority): - - priority="standard": Vague complaints, missing data, billing issues, service complaints - - priority="urgent": Legal threats, health/safety concerns, suspected fraud, abusive language - - Output Requirements: - - For suggest_credit: credit_amount is REQUIRED and must be a 
number (can be 0.0 if no credit warranted), confidence is REQUIRED, priority must be null - - For escalate: priority is REQUIRED, credit_amount and confidence must be null - - complaint_category: Choose EXACTLY ONE category (the primary one) - - Rationale must cite specific evidence (delivery times, percentiles, item verification, order total) - - Rationale should be detailed but under 150 words - - Round credit amounts to nearest $0.50 - - Confidence: high (strong data), medium (reasonable inference), low (weak/contradictory) - """ - - complaint: str = dspy.InputField(desc="Customer complaint text") - order_id: str = dspy.OutputField(desc="Extracted order ID") - complaint_category: str = dspy.OutputField( - desc="EXACTLY ONE category: delivery_delay, missing_items, food_quality, service_issue, billing, or other" - ) - decision: str = dspy.OutputField(desc="EXACTLY ONE: suggest_credit or escalate") - credit_amount: str = dspy.OutputField(desc="If suggest_credit: a number. If escalate: null") - confidence: str = dspy.OutputField(desc="If suggest_credit: high, medium, or low. If escalate: null") - priority: str = dspy.OutputField(desc="If escalate: standard or urgent. 
If suggest_credit: null") - rationale: str = dspy.OutputField(desc="Data-focused justification citing specific evidence") - - def get_order_overview(order_id: str) -> str: """Get order details including items, location, and customer info.""" result = _client().execute_function(f"{CATALOG}.ai.get_order_overview", {"oid": order_id}) @@ -197,40 +194,105 @@ def get_location_timings(location: str) -> str: return str(result.value) -class ComplaintTriageModule(dspy.Module): - def __init__(self): - super().__init__() - self.react = dspy.ReAct( - signature=ComplaintTriage, - tools=[get_order_overview, get_order_timing, get_location_timings], - max_iters=10, - ) - - def forward(self, complaint: str, max_retries: int = 2) -> ComplaintResponse: - for attempt in range(max_retries + 1): +_ORDER_ID_RE = re.compile(r"\border\s*id\s*[:#-]?\s*([A-Za-z0-9]{6}(?:-L\d+)?)\b", re.IGNORECASE) +_FALLBACK_ID_RE = re.compile(r"\b[A-Z0-9]{6}(?:-L\d+)?\b") + + +def _extract_order_id(text: str) -> str: + match = _ORDER_ID_RE.search(text) + if match: + return match.group(1).upper() + match = _FALLBACK_ID_RE.search(text.upper()) + if match: + return match.group(0) + raise ValueError("No order_id found in complaint") + + +def _extract_location(order_overview: str) -> Optional[str]: + match = re.search(r"['\"]location['\"]\s*[:=]\s*['\"]([^'\"]+)['\"]", order_overview) + if match: + return match.group(1) + for location in ("San Francisco", "Silicon Valley", "Bellevue", "Chicago"): + if location.lower() in order_overview.lower(): + return location + return None + + +def _lm_text(outputs) -> str: + if isinstance(outputs, str): + return outputs + if isinstance(outputs, list) and outputs: + first = outputs[0] + if isinstance(first, str): + return first + if isinstance(first, dict): + for key in ("text", "content", "answer", "response"): + if key in first and first[key]: + return str(first[key]) + return json.dumps(first) + return str(outputs) + + +def _parse_response(text: str) -> ComplaintResponse: + 
start = text.find("{") + end = text.rfind("}") + if start < 0 or end < start: + raise ValueError(f"Complaint agent returned no JSON object: {text}") + payload = json.loads(text[start : end + 1]) + return ComplaintResponse.model_validate(payload) + + +def _triage_prompt( + complaint: str, + order_id: str, + order_overview: str, + order_timing: str, + location_timings: str, +) -> str: + return f"""Analyze this Casper's Kitchens customer complaint and return only JSON. + +Customer complaint: +{complaint} + +Order id: +{order_id} + +Order overview from Unity Catalog: +{order_overview} + +Order timing from Unity Catalog: +{order_timing} + +Location delivery percentiles from Unity Catalog: +{location_timings or "Unavailable"} + +{COMPLAINT_TRIAGE_PROMPT.replace("", order_id)}""" + + +def _run_triage(complaint: str) -> ComplaintResponse: + order_id = _extract_order_id(complaint) + order_overview = get_order_overview(order_id) + order_timing = get_order_timing(order_id) + location = _extract_location(order_overview) + location_timings = get_location_timings(location) if location else "" + lm = _build_lm() + + prompt = _triage_prompt(complaint, order_id, order_overview, order_timing, location_timings) + last_text = "" + # `dspy.context(...)` is the thread-safe, per-request settings override — + # safe to call from the AgentServer worker thread, unlike `dspy.configure()` + # (see the module-load comment above). 
+ with dspy.context(lm=lm): + for attempt in range(2): + outputs = lm(messages=[{"role": "user", "content": prompt}]) + last_text = _lm_text(outputs) try: - result = self.react(complaint=complaint) - credit_amount = None - if result.credit_amount and result.credit_amount.lower() != "null": - try: - credit_amount = float(result.credit_amount) - except (ValueError, TypeError): - credit_amount = None - if result.decision == "suggest_credit" and credit_amount is None: - credit_amount = 0.0 - return ComplaintResponse( - order_id=result.order_id, - complaint_category=result.complaint_category, - decision=result.decision, - credit_amount=credit_amount, - confidence=result.confidence, - priority=result.priority, - rationale=result.rationale, - ) - except (ValidationError, ValueError): - if attempt >= max_retries: - raise - raise RuntimeError("Complaint triage failed after retries") + return _parse_response(last_text) + except (json.JSONDecodeError, ValidationError, ValueError) as exc: + if attempt: + raise ValueError(f"Invalid complaint agent JSON: {last_text}") from exc + prompt += f"\n\nYour previous response was invalid: {last_text}\nReturn only valid JSON with the required shape." 
+ raise RuntimeError("Complaint triage failed") def _msg_to_dict(msg) -> dict: @@ -254,7 +316,6 @@ def _text_output(text: str, item_id: str | None = None) -> dict: @invoke() def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse: - _configure_dspy() complaint = None for msg in request.input: msg_dict = _msg_to_dict(msg) @@ -264,7 +325,7 @@ def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse: if not complaint: raise ValueError("No user message found in request") - result = ComplaintTriageModule()(complaint=complaint) + result = _run_triage(complaint) return ResponsesAgentResponse( output=[_text_output(result.model_dump_json())], custom_outputs=request.custom_inputs, diff --git a/apps/complaint-agent/app.yaml b/apps/complaint-agent/app.yaml index ec46f8c..307d59b 100644 --- a/apps/complaint-agent/app.yaml +++ b/apps/complaint-agent/app.yaml @@ -4,7 +4,10 @@ command: env: - name: DATABRICKS_CATALOG value: caspersdev - - name: LLM_MODEL + # Unity AI Gateway endpoint the agent routes every LLM call through. + # The deploy stage overwrites this app.yaml with the per-catalog value; + # this static default is for local/dev runs only. 
+ - name: AI_GATEWAY_ENDPOINT_NAME value: databricks-claude-sonnet-4-5 - name: MLFLOW_TRACKING_URI value: databricks diff --git a/apps/complaint-agent/start_server.py b/apps/complaint-agent/start_server.py index 5bc2591..8db2701 100644 --- a/apps/complaint-agent/start_server.py +++ b/apps/complaint-agent/start_server.py @@ -1,45 +1,19 @@ -import inspect import os +import sys import agent # noqa: F401 - registers @invoke with MLflow AgentServer -from fastapi import Request -from fastapi.responses import JSONResponse -from mlflow.genai.agent_server import AgentServer, setup_mlflow_git_based_version_tracking -from mlflow.types.responses import ResponsesAgentRequest +from mlflow.genai.agent_server import AgentServer agent_server = AgentServer("ResponsesAgent") app = agent_server.app -setup_mlflow_git_based_version_tracking() - - -@app.post("/responses") -@app.post("/api/responses") -async def responses(request: Request): - from mlflow.genai.agent_server import get_invoke_function - - body = await request.json() - if body.get("stream"): - return JSONResponse( - status_code=400, - content={"error": "stream=true is not supported by this agent app"}, - ) - - invoke_fn = get_invoke_function() - result = invoke_fn(ResponsesAgentRequest(**body)) - if inspect.isawaitable(result): - result = await result - return JSONResponse(content=result.model_dump(mode="json")) - def main(): - port = int(os.environ.get("DATABRICKS_APP_PORT", "8000")) - agent_server.run( - app_import_string="start_server:app", - host="0.0.0.0", - port=port, - ) + app_port = os.environ.get("DATABRICKS_APP_PORT") + if app_port and "--port" not in sys.argv: + sys.argv.extend(["--port", app_port]) + agent_server.run(app_import_string="start_server:app") if __name__ == "__main__": diff --git a/apps/refund-agent/agent.py b/apps/refund-agent/agent.py index 64b152a..da74cc7 100644 --- a/apps/refund-agent/agent.py +++ b/apps/refund-agent/agent.py @@ -29,7 +29,14 @@ mlflow.langchain.autolog() CATALOG = 
os.environ["DATABRICKS_CATALOG"] -LLM_MODEL = os.environ["LLM_MODEL"] +# Unity AI Gateway endpoint that ALL of this agent's LLM calls route through. +# Gateway-always-on: there is no model-serving fallback. Sent verbatim as the +# `model` field to /ai-gateway/mlflow/v1, so it must name a queryable +# gateway route (its CAN_QUERY is granted to the App SP manually — see runbook). +# Falls back to LLM_MODEL during the migration to the dedicated +# AI_GATEWAY_ENDPOINT_NAME param so the App works whether the deploy stage +# injects the old or new env var. +GATEWAY_ENDPOINT_NAME = os.environ.get("AI_GATEWAY_ENDPOINT_NAME") or os.environ["LLM_MODEL"] PROMPT_URI = f"prompts:/{CATALOG}.prompts.refund_system@production" _FALLBACK_PROMPT = """You are RefundGPT, a CX agent responsible for refund decisions on food delivery orders. @@ -81,15 +88,23 @@ def _auth_header() -> str: return header +def _workspace_host() -> str: + host = (os.environ.get("DATABRICKS_HOST") or Config().host or "").rstrip("/") + if not host: + raise RuntimeError("Databricks workspace host is unavailable") + if not host.startswith(("http://", "https://")): + host = f"https://{host}" + return host + + def _validate_gateway_endpoint() -> None: - host = (os.environ.get("DATABRICKS_HOST") or Config().host).rstrip("/") client = OpenAI( api_key=_auth_header().removeprefix("Bearer "), - base_url=f"{host}/ai-gateway/mlflow/v1", + base_url=f"{_workspace_host()}/ai-gateway/mlflow/v1", timeout=30, ) client.chat.completions.create( - model=LLM_MODEL, + model=GATEWAY_ENDPOINT_NAME, messages=[{"role": "user", "content": "Say gateway ok."}], max_tokens=8, ) @@ -119,6 +134,22 @@ class RefundDecision(BaseModel): reason: str = "" +def _parse_refund_decision(text: str) -> RefundDecision | None: + try: + return RefundDecision.model_validate_json(text) + except (ValidationError, ValueError, TypeError): + pass + + start = text.find("{") + end = text.rfind("}") + if start < 0 or end < start: + return None + try: + return 
RefundDecision.model_validate_json(text[start : end + 1]) + except (ValidationError, ValueError, TypeError): + return None + + @tool def get_order_details(order_id: str) -> str: """Get the full event history for an order.""" @@ -184,7 +215,7 @@ def call_model(state: ChatAgentState, config: RunnableConfig): return workflow.compile() -LLM = ChatDatabricks(model=LLM_MODEL, use_ai_gateway=True) +LLM = ChatDatabricks(model=GATEWAY_ENDPOINT_NAME, use_ai_gateway=True) AGENT: CompiledStateGraph = create_tool_calling_agent(LLM, TOOLS, SYSTEM_PROMPT) @@ -217,11 +248,10 @@ def _run_agent(messages: list[dict]) -> str: role = msg.get("role") if isinstance(msg, dict) else getattr(msg, "role", None) content = msg.get("content", "") if isinstance(msg, dict) else getattr(msg, "content", "") if role == "assistant" and content: - try: - parsed = RefundDecision.model_validate_json(content) + parsed = _parse_refund_decision(str(content)) + if parsed: return parsed.model_dump_json() - except (ValidationError, ValueError, TypeError): - return str(content) + return str(content) raise RuntimeError("Refund agent produced no assistant message") diff --git a/apps/refund-agent/app.yaml b/apps/refund-agent/app.yaml index ec46f8c..307d59b 100644 --- a/apps/refund-agent/app.yaml +++ b/apps/refund-agent/app.yaml @@ -4,7 +4,10 @@ command: env: - name: DATABRICKS_CATALOG value: caspersdev - - name: LLM_MODEL + # Unity AI Gateway endpoint the agent routes every LLM call through. + # The deploy stage overwrites this app.yaml with the per-catalog value; + # this static default is for local/dev runs only. 
+ - name: AI_GATEWAY_ENDPOINT_NAME value: databricks-claude-sonnet-4-5 - name: MLFLOW_TRACKING_URI value: databricks diff --git a/apps/refund-agent/start_server.py b/apps/refund-agent/start_server.py index c0dc438..8db2701 100644 --- a/apps/refund-agent/start_server.py +++ b/apps/refund-agent/start_server.py @@ -1,52 +1,19 @@ -import inspect -import json import os +import sys import agent # noqa: F401 - registers @invoke with MLflow AgentServer -from fastapi import Request -from fastapi.responses import JSONResponse -from mlflow.genai.agent_server import AgentServer, setup_mlflow_git_based_version_tracking -from mlflow.types.responses import ResponsesAgentRequest +from mlflow.genai.agent_server import AgentServer agent_server = AgentServer("ResponsesAgent") app = agent_server.app -setup_mlflow_git_based_version_tracking() - - -@app.post("/responses") -@app.post("/api/responses") -async def responses(request: Request): - """Databricks Apps agent-compatible Responses API alias. - - MLflow AgentServer serves /invocations locally. Databricks Apps agent - clients use /responses, so expose the same registered invoke function on - that route too. 
- """ - from mlflow.genai.agent_server import get_invoke_function - - body = await request.json() - if body.get("stream"): - return JSONResponse( - status_code=400, - content={"error": "stream=true is not supported by this agent app"}, - ) - - invoke_fn = get_invoke_function() - result = invoke_fn(ResponsesAgentRequest(**body)) - if inspect.isawaitable(result): - result = await result - return JSONResponse(content=result.model_dump(mode="json")) - def main(): - port = int(os.environ.get("DATABRICKS_APP_PORT", "8000")) - agent_server.run( - app_import_string="start_server:app", - host="0.0.0.0", - port=port, - ) + app_port = os.environ.get("DATABRICKS_APP_PORT") + if app_port and "--port" not in sys.argv: + sys.argv.extend(["--port", app_port]) + agent_server.run(app_import_string="start_server:app") if __name__ == "__main__": diff --git a/databricks.yml b/databricks.yml index d94aa22..51a419e 100644 --- a/databricks.yml +++ b/databricks.yml @@ -23,9 +23,12 @@ sync: - .claude/** - .databricks/** - .venv/** + - "**/.venv/**" - .bundle/** - __pycache__/** + - "**/__pycache__/**" - "*.pyc" + - "**/*.pyc" # Repo content we never want shipped - ./data/universe/** - images/** @@ -597,6 +600,45 @@ targets: # Operational Dashboard — Multi-Agent Supervisor endpoint (user-supplied) - name: SUPERVISOR_ENDPOINT_NAME default: ${var.catalog}_operational_supervisor + # Unity AI Gateway endpoint name. DISTINCT from LLM_MODEL above: + # LLM_MODEL is the foundation model the generators / support agent + # call directly; AI_GATEWAY_ENDPOINT_NAME is the governed gateway + # route the Refund + Complaint App agents send every LLM call + # through (gateway-always-on, no model-serving fallback) so PII + # guardrails / inference tables / usage tracking apply centrally. + # Sent verbatim as the request `model` to + # /ai-gateway/mlflow/v1. The v2 Beta gateway is UI-created + # only and CAN_QUERY must be granted to each agent App's SP by hand + # — see the runbook. 
Default matches the historical value so + # existing deploys keep working; override with + # `--params AI_GATEWAY_ENDPOINT_NAME=<endpoint-name>`. + - name: AI_GATEWAY_ENDPOINT_NAME + default: databricks-claude-sonnet-4-5 + # Names that MUST carry the deploy-time catalog, not the run-time + # one. These resources are created by `bundle deploy` using + # `${var.catalog}` (`--var catalog`), but CATALOG above is + # overridable at run time (`--params CATALOG=...`). Reconstructing + # the names from the run-time CATALOG widget breaks the moment the + # two dials disagree — the deploy created `<deploy-catalog>-...` + # while the stage looks for `<run-catalog>-...`. Baking the name + # into the param DEFAULT freezes it at deploy time. Stages read + # these params, with the legacy reconstruction kept as a fallback. + # + # Ops warehouse — DABs-managed + # (resources.sql_warehouses.caspers_ops_warehouse). + - name: OPS_WAREHOUSE_NAME + default: ${var.catalog}-ops-warehouse + # Agent App names — the Refund + Complaint agents are deployed as + # Databricks Apps (apps/refund-agent, apps/complaint-agent), not + # Model Serving endpoints. Consumers (ops dashboard, stream jobs, + # eval stages) reach them by App name via utils/agent_app_client.py. + # NOTE: App names are capped at 30 chars and sanitised by + # `_safe_app_name`; for long catalogs the stage re-sanitises this + # value, so the param and helper agree for normal catalogs. 
+ - name: REFUND_AGENT_APP_NAME + default: refund-agent-${var.catalog} + - name: COMPLAINT_AGENT_APP_NAME + default: complaint-agent-${var.catalog} tasks: # ─── Data foundation ────────────────────────────────────────────── diff --git a/demos/dais2026-runbooks/MLflow.ipynb b/demos/dais2026-runbooks/MLflow.ipynb index d6811b1..36a61ed 100644 --- a/demos/dais2026-runbooks/MLflow.ipynb +++ b/demos/dais2026-runbooks/MLflow.ipynb @@ -280,7 +280,95 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Unity AI Gateway" + "### Unity AI Gateway\n", + "\n", + "Both Refund and Complaint route **every** internal LLM call through a single\n", + "gateway endpoint (`dais2026-ai-gateway`) so guardrails / inference table /\n", + "usage tracking / rate limits apply to every agent LLM call centrally.\n", + "Refund uses `databricks_langchain.ChatDatabricks(use_ai_gateway=True)`,\n", + "Complaint uses `dspy.LM('openai/...')` against `/ai-gateway/mlflow/v1`.\n", + "Both agents are **Databricks Apps** (`apps/refund-agent`,\n", + "`apps/complaint-agent`) running MLflow `AgentServer`, and gateway routing is\n", + "always on (no foundation-model fallback). Tokens refresh per request —\n", + "Refund via `DatabricksOpenAI`'s `BearerAuth`, Complaint by rebuilding its\n", + "`dspy.LM` and applying it with `dspy.context(...)`.\n", + "\n", + "**Show:**\n", + "- `AI Gateway` → `dais2026-ai-gateway` → **Permissions** tab → each agent\n", + " App's service principal (`refund-agent-`,\n", + " `complaint-agent-`) has CAN_QUERY. 
(App SPs are *not* in `account\n", + " users`, so they're granted directly — see SETUP.ipynb step 5.7.)\n", + "- `AI Gateway` → `dais2026-ai-gateway` → **Inference Tables** /\n", + " **Guardrails** / **Usage Tracking** / **Rate Limits** all enabled\n", + "- Refund prod experiment → recent trace → the `ChatDatabricks` LLM span has\n", + " `endpoint = /ai-gateway/mlflow/v1` (vs the old foundation-model\n", + " serving URL)\n", + "- Complaint prod experiment → recent trace → DSPy LM span has the same\n", + " gateway URL\n", + "\n", + "**Inference table — one row per agent LLM call:**\n", + "\n", + "```sql\n", + "-- Replace the catalog/schema with whatever you pointed inference tables at\n", + "-- in step 5 of SETUP.ipynb (default suggestion: .ai_gateway).\n", + "SELECT\n", + " event_time,\n", + " request_id,\n", + " status_code, -- 200 = served, 400 = blocked by guardrail\n", + " requester, -- the agent App's SP — attribute usage by agent\n", + " latency_ms,\n", + " CAST(request AS STRING) AS request_preview,\n", + " CAST(response AS STRING) AS response_preview\n", + "FROM ``.`ai_gateway`.`dais2026-ai-gateway_payload`\n", + "ORDER BY event_time DESC\n", + "LIMIT 20\n", + "```\n", + "\n", + "**Usage tracking — per-endpoint token spend (account-admin only):**\n", + "\n", + "```sql\n", + "SELECT\n", + " endpoint_name,\n", + " DATE_TRUNC('hour', usage_time) AS hour,\n", + " SUM(input_tokens) AS input_tokens,\n", + " SUM(output_tokens) AS output_tokens,\n", + " SUM(input_tokens + output_tokens) AS total_tokens\n", + "FROM system.ai_gateway.usage\n", + "WHERE endpoint_name = 'dais2026-ai-gateway'\n", + " AND usage_time >= CURRENT_TIMESTAMP - INTERVAL 1 HOUR\n", + "GROUP BY endpoint_name, hour\n", + "ORDER BY hour DESC\n", + "```\n", + "\n", + "**Guardrails in action — fire a PII complaint and watch it block:**\n", + "\n", + "```python\n", + "# Run in any notebook attached to a cluster. 
PII detection is set to\n", + "# BLOCK in the gateway config, so the request never reaches the LLM —\n", + "# the row in the inference table will show status_code = 400 and the\n", + "# guardrail label.\n", + "import requests\n", + "from databricks.sdk import WorkspaceClient\n", + "w = WorkspaceClient()\n", + "\n", + "resp = requests.post(\n", + " f\"{w.config.host.rstrip('/')}/ai-gateway/mlflow/v1/chat/completions\",\n", + " headers=w.config.authenticate(),\n", + " json={\n", + " \"model\": \"dais2026-ai-gateway\",\n", + " \"messages\": [{\"role\": \"user\", \"content\":\n", + " \"My order ORD-123 was late! My SSN is 123-45-6789, refund please.\"\n", + " }],\n", + " \"max_tokens\": 64,\n", + " },\n", + " timeout=30,\n", + ")\n", + "print(resp.status_code, resp.text[:200])\n", + "```\n", + "\n", + "**Message:** one gateway, two App-hosted agents, every LLM call governed —\n", + "guardrails stop PII before the model sees it, every request audited in UC,\n", + "usage attributable per agent App SP via the `requester` column." ] } ], diff --git a/demos/dais2026-runbooks/SETUP.ipynb b/demos/dais2026-runbooks/SETUP.ipynb index 57b03cc..371bcba 100644 --- a/demos/dais2026-runbooks/SETUP.ipynb +++ b/demos/dais2026-runbooks/SETUP.ipynb @@ -100,6 +100,158 @@ "print(result)" ] }, + { + "cell_type": "markdown", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "00000000-0000-0000-0000-000000000008", + "showTitle": false, + "title": "" + } + }, + "source": [ + "### 5. 
Set up the Unity AI Gateway endpoint (`all` target only)\n", + "\n", + "Both the Refund agent (LangGraph + `databricks_langchain.ChatDatabricks`)\n", + "and the Complaint agent (DSPy + `dspy.LM('openai/...')`) on the `all`\n", + "target route **every** internal LLM call through a single Unity AI Gateway\n", + "(v2 Beta) endpoint, so every agent LLM call gets PII guardrails,\n", + "inference-table audit, usage tracking, and rate limits applied centrally —\n", + "and you can attribute usage to either agent via the `requester` column on\n", + "the inference table.\n", + "\n", + "Both agents are deployed as **Databricks Apps** (`apps/refund-agent`,\n", + "`apps/complaint-agent`) running the MLflow `AgentServer`, not as Model\n", + "Serving endpoints. Gateway routing is **always on** — there is no\n", + "foundation-model fallback. Each App reads the gateway endpoint name from\n", + "its `AI_GATEWAY_ENDPOINT_NAME` env var (set from the job parameter of the\n", + "same name at deploy time) and POSTs to `/ai-gateway/mlflow/v1` with\n", + "that name as the request `model`.\n", + "\n", + "The v2 Beta gateway is **UI-configured only** — there is no public REST /\n", + "SDK API for creating, modifying, or granting permission on it (per\n", + "[Configure Unity AI Gateway endpoints](https://docs.databricks.com/aws/en/ai-gateway/configure-endpoints-beta)).\n", + "Do this manually in the workspace UI, **before** deploying the `all` target:\n", + "\n", + "1. **Enable the preview**\n", + "\n", + " Account console → **Previews** → toggle **Unity AI Gateway** on. (Account\n", + " admin only; skip if already enabled.)\n", + "\n", + "2. **Create the endpoint**\n", + "\n", + " Workspace sidebar → **AI Gateway** → **Create Unity AI Gateway Endpoint**.\n", + "\n", + " - **Name**: `dais2026-ai-gateway` (or pick another name — pass it via\n", + " `--params \"AI_GATEWAY_ENDPOINT_NAME=\"` at deploy time). 
The job\n", + " parameter defaults to `databricks-claude-sonnet-4-5`, so if you skip\n", + " this param the agents target a foundation model by that name directly;\n", + " set it to your governed endpoint name to get guardrails / audit.\n", + " - **Primary model**: a foundation model whose tool-use is good. We use\n", + " `databricks-claude-sonnet-4-5` to match the agents' default `LLM_MODEL`.\n", + " - Click **Create**.\n", + "\n", + "3. **Enable Inference Tables**\n", + "\n", + " Endpoint detail page → **Inference Tables** → **Edit** → enable.\n", + " Point it at a UC schema you control (e.g. `.ai_gateway`).\n", + " This populates `.ai_gateway._payload` with one row\n", + " per request — both allowed and blocked.\n", + "\n", + "4. **Enable Usage Tracking**\n", + "\n", + " Endpoint detail page → **Usage Tracking** → **Edit** → enable. Per-request\n", + " token counts land in `system.ai_gateway.usage` (account admins only —\n", + " if your user can't query that table, ask an account admin to grant\n", + " `SELECT ON SCHEMA system.ai_gateway` to your group).\n", + "\n", + "5. **Configure Guardrails**\n", + "\n", + " Endpoint detail page → **Guardrails** → **Edit**. Enable:\n", + " - **PII Detection** = **Block** (rejects requests containing SSNs, credit\n", + " cards, etc — returns HTTP 400 before the LLM sees them).\n", + " - **Jailbreak and Prompt Injection** = on.\n", + " - **Unsafe Content** = on.\n", + "\n", + "6. **Configure Rate Limits** (optional)\n", + "\n", + " Endpoint detail page → **Rate Limits** → **Edit**. Set a per-user QPM /\n", + " TPM limit if you want to demo the burst-test in `MLflow.ipynb` →\n", + " \"Unity AI Gateway\" section. Skip otherwise.\n", + "\n", + "7. **Grant CAN_QUERY to each agent App's service principal** *(after first deploy)*\n", + "\n", + " Each Databricks App runs as its own service principal, and that SP — not\n", + " `account users` — is what calls the gateway at request time. 
Unlike Model\n", + " Serving endpoints, an App's SP is **not** automatically a member of\n", + " `account users`, so a blanket `account users` grant will not cover it.\n", + " Grant `CAN_QUERY` to each agent App's SP directly:\n", + "\n", + " - First deploy the `all` target (step 8) so the two agent Apps —\n", + " `refund-agent-` and `complaint-agent-` — exist and\n", + " have SPs. Find each SP with:\n", + "\n", + " ```python\n", + " from databricks.sdk import WorkspaceClient\n", + " w = WorkspaceClient()\n", + " for app in (\"refund-agent-\", \"complaint-agent-\"):\n", + " a = w.apps.get(app)\n", + " print(app, \"→ SP:\", a.service_principal_client_id, a.service_principal_name)\n", + " ```\n", + "\n", + " - Then: AI Gateway → `dais2026-ai-gateway` → **Permissions** →\n", + " **Add user / group** → paste each App SP → **CAN_QUERY**.\n", + "\n", + " Why this is manual: v2 Beta gateway endpoints live on a separate API\n", + " surface from regular serving endpoints, with no permissions API, and they\n", + " can't be listed in `mlflow.models.resources.DatabricksServingEndpoint(...)`\n", + " (it crashes with `NOT_FOUND: Dependent serving endpoint does not\n", + " exist`). The agent stages therefore probe the gateway with the *notebook*\n", + " identity at deploy time and let the App validate it at startup with the\n", + " *App SP* — but the App SP can only succeed once you've granted it CAN_QUERY\n", + " here. If the grant is missing, the App's import-time gateway smoke-check\n", + " fails and the App will not start.\n", + "\n", + " > **First-deploy ordering:** the agent Apps' startup validation needs this\n", + " > grant, but the grant needs the Apps to exist. If the very first deploy\n", + " > fails at App startup with a gateway-auth error, grant CAN_QUERY to the two\n", + " > App SPs (now that they exist) and redeploy / restart the Apps.\n", + "\n", + "8. 
**Deploy the `all` target with the gateway wired in**\n", + "\n", + " ```bash\n", + " databricks bundle deploy -t all --var catalog=\n", + " databricks bundle run caspers -t all \\\n", + " --params \"CATALOG=,AI_GATEWAY_ENDPOINT_NAME=dais2026-ai-gateway\"\n", + " ```\n", + "\n", + " Pass the **same** catalog to `--var catalog` (deploy time) and\n", + " `--params CATALOG` (run time) — see the `--var` vs `--params` note in\n", + " `README.md` / `AGENTS.md`. The agent App names are baked at deploy time\n", + " into the `REFUND_AGENT_APP_NAME` / `COMPLAINT_AGENT_APP_NAME` job params,\n", + " so consumers (stream jobs, eval stages, ops dashboard) resolve the same\n", + " names the agent stages deployed even if the two catalogs ever drift.\n", + "\n", + " At request time both agents refresh their OAuth bearer automatically:\n", + " - Refund's `ChatDatabricks(use_ai_gateway=True)` authenticates through\n", + " `databricks_openai.DatabricksOpenAI`, whose httpx `BearerAuth` calls\n", + " `config.authenticate()` on every request — rotated M2M tokens are\n", + " picked up with no per-request rebuild.\n", + " - Complaint rebuilds its `dspy.LM` with a fresh bearer per request and\n", + " applies it via `dspy.context(lm=...)` (thread-safe on the AgentServer\n", + " worker threads — `dspy.configure()` is only called once at import).\n", + "\n", + "> **Verify:** after deploy, send one request to each agent App\n", + "> (`refund-agent-`, `complaint-agent-`) — the ops\n", + "> dashboard's agent tiles, or the stream jobs, will do this — and check that\n", + "> rows appear in `.ai_gateway._payload`. Each should show\n", + "> HTTP 200 with non-zero token counts and the App SP in the `requester`\n", + "> column. If you see HTTP 400 \"Invalid Token\" or the Apps fail to start,\n", + "> re-check the CAN_QUERY grant in step 7." 
+ ] + }, { "cell_type": "markdown", "metadata": { diff --git a/jobs/complaint_agent_stream.ipynb b/jobs/complaint_agent_stream.ipynb index 903666c..d28a330 100644 --- a/jobs/complaint_agent_stream.ipynb +++ b/jobs/complaint_agent_stream.ipynb @@ -18,12 +18,15 @@ "import os\n", "import sys\n", "sys.path.append(os.path.abspath(\"../utils\"))\n", - "from agent_app_client import app_request_context, complaint_agent_app_name, extract_response_text\n", + "from agent_app_client import app_request_context, extract_response_text, resolve_agent_app_name\n", "\n", "try:\n", - " COMPLAINT_AGENT_APP_NAME = dbutils.widgets.get(\"COMPLAINT_AGENT_APP_NAME\")\n", + " _APP_NAME_PARAM = dbutils.widgets.get(\"COMPLAINT_AGENT_APP_NAME\")\n", "except Exception:\n", - " COMPLAINT_AGENT_APP_NAME = complaint_agent_app_name(CATALOG)\n", + " _APP_NAME_PARAM = \"\"\n", + "# Prefer the deploy-time-baked param so the name carries the deploy-time catalog\n", + "# even when --params CATALOG disagrees with --var catalog; resolver re-sanitises.\n", + "COMPLAINT_AGENT_APP_NAME = resolve_agent_app_name(_APP_NAME_PARAM, CATALOG, \"complaint\")\n", "\n", "_AGENT_APP_CONTEXT = app_request_context(app_name=COMPLAINT_AGENT_APP_NAME, dbutils=dbutils)\n", "COMPLAINT_AGENT_APP_URL = _AGENT_APP_CONTEXT[\"url\"]\n", @@ -82,39 +85,16 @@ "# - Refund + support streams use 50 because their agents are faster;\n", "# do not blindly copy that cap here without re-measuring.\n", "CHECKPOINT_PATH = f\"/Volumes/{CATALOG}/complaints/checkpoints/complaint_agent_stream\"\n", - "# Sized for the 600s task timeout. The complaint agent (DSPy ReAct) measured\n", - "# at ~38s per call \u2014 markedly slower than the refunder (~12s) because it does\n", - "# more tool-call iterations per complaint. 5 \u00d7 38 \u2248 190s leaves room for\n", - "# cold-start jitter, Delta write, and is_first_run() probes; 10 \u00d7 38 = 380s\n", + "# Sized for the 600s task timeout. 
The complaint agent measured around 25s\n", + "# per call through Apps + Gateway during smoke testing. 5 \u00d7 25 \u2248 125s leaves room for\n", + "# cold-start jitter, Delta write, and is_first_run() probes; at the older ~38s/call, 10 \u00d7 38 = 380s\n", "# + ~250s overhead was tipping over the 600s wall once the agent was\n", "# fully healthy.\n", "MAX_INFERENCES_PER_BATCH = 5\n", "\n", "\n", "def _extract_agent_text(response):\n", - " if not isinstance(response, dict):\n", - " return str(response)\n", - " direct = response.get(\"output_text\")\n", - " if isinstance(direct, str) and direct:\n", - " return direct\n", - " output = response.get(\"output\") or []\n", - " if isinstance(output, dict):\n", - " output = [output]\n", - " for item in output:\n", - " if not isinstance(item, dict):\n", - " continue\n", - " content = item.get(\"content\")\n", - " if isinstance(content, str) and content:\n", - " return content\n", - " if isinstance(content, dict):\n", - " content = [content]\n", - " if isinstance(content, list):\n", - " for content_item in content:\n", - " if isinstance(content_item, dict):\n", - " text = content_item.get(\"text\")\n", - " if isinstance(text, str) and text:\n", - " return text\n", - " raise ValueError(\"Could not extract final response text\")\n", + " return extract_response_text(response)\n", "\n", "\n", "def is_first_run():\n", @@ -428,4 +408,4 @@ }, "nbformat": 4, "nbformat_minor": 2 -} \ No newline at end of file +} diff --git a/jobs/refund_recommender_stream.ipynb b/jobs/refund_recommender_stream.ipynb index b9baeef..3306a58 100644 --- a/jobs/refund_recommender_stream.ipynb +++ b/jobs/refund_recommender_stream.ipynb @@ -37,12 +37,15 @@ "import os\n", "import sys\n", "sys.path.append(os.path.abspath(\"../utils\"))\n", - "from agent_app_client import app_request_context, extract_response_text, refund_agent_app_name\n", + "from agent_app_client import app_request_context, extract_response_text, resolve_agent_app_name\n", "\n", "try:\n", - " REFUND_AGENT_APP_NAME 
= dbutils.widgets.get(\"REFUND_AGENT_APP_NAME\")\n", + " _APP_NAME_PARAM = dbutils.widgets.get(\"REFUND_AGENT_APP_NAME\")\n", "except Exception:\n", - " REFUND_AGENT_APP_NAME = refund_agent_app_name(CATALOG)\n", + " _APP_NAME_PARAM = \"\"\n", + "# Prefer the deploy-time-baked param so the name carries the deploy-time catalog\n", + "# even when --params CATALOG disagrees with --var catalog; resolver re-sanitises.\n", + "REFUND_AGENT_APP_NAME = resolve_agent_app_name(_APP_NAME_PARAM, CATALOG, \"refund\")\n", "\n", "_AGENT_APP_CONTEXT = app_request_context(app_name=REFUND_AGENT_APP_NAME, dbutils=dbutils)\n", "REFUND_AGENT_APP_URL = _AGENT_APP_CONTEXT[\"url\"]\n", @@ -191,29 +194,7 @@ "\n", "\n", "def _extract_agent_text(response):\n", - " if not isinstance(response, dict):\n", - " return str(response)\n", - " direct = response.get(\"output_text\")\n", - " if isinstance(direct, str) and direct:\n", - " return direct\n", - " output = response.get(\"output\") or []\n", - " if isinstance(output, dict):\n", - " output = [output]\n", - " for item in output:\n", - " if not isinstance(item, dict):\n", - " continue\n", - " content = item.get(\"content\")\n", - " if isinstance(content, str) and content:\n", - " return content\n", - " if isinstance(content, dict):\n", - " content = [content]\n", - " if isinstance(content, list):\n", - " for content_item in content:\n", - " if isinstance(content_item, dict):\n", - " text = content_item.get(\"text\")\n", - " if isinstance(text, str) and text:\n", - " return text\n", - " raise ValueError(\"Could not extract final response text\")\n", + " return extract_response_text(response)\n", "\n", "# Per-call timeout (seconds) for the agent app. 
Without this, a wedged\n", "# app call can hang the UDF forever and pin the 10-min task timeout against\n", @@ -554,4 +535,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/stages/complaint_agent.ipynb b/stages/complaint_agent.ipynb index 157e486..fbe291a 100644 --- a/stages/complaint_agent.ipynb +++ b/stages/complaint_agent.ipynb @@ -6,7 +6,7 @@ "source": [ "#### complaint agent\n", "\n", - "Builds and ships an order-complaint agent using DSPy: author Unity Catalog tools, assemble the DSPy ReAct workflow in the Databricks App source, and deploy the app with MLflow AgentServer.\n" + "Builds and ships an order-complaint agent using DSPy: author Unity Catalog tools, assemble the Gateway-backed DSPy workflow in the Databricks App source, and deploy the app with MLflow AgentServer.\n" ], "id": "cell-0" }, @@ -232,7 +232,7 @@ "\n", "- Install orchestration dependencies and restart Python for a clean runtime.\n", "- Capture widget inputs (`CATALOG`, `LLM_MODEL`) and resolve the deterministic Databricks App name.\n", - "- Use `../apps/complaint-agent` as the source of truth for the DSPy ReAct complaint workflow.\n", + "- Use `../apps/complaint-agent` as the source of truth for the Gateway-backed DSPy complaint workflow.\n", "- Treat `LLM_MODEL` as a Unity AI Gateway endpoint name; no custom-agent LLM calls use legacy model-serving invocation routes.\n" ], "id": "dq5ml4wp6v" @@ -257,9 +257,26 @@ "\n", "import sys\n", "sys.path.append('../utils')\n", - "from agent_app_client import complaint_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", - "APP_NAME = complaint_agent_app_name(CATALOG)\n", + "\n", + "def _param(name, default=\"\"):\n", + " try:\n", + " return (dbutils.widgets.get(name) or \"\").strip() or default\n", + " except Exception:\n", + " return default\n", + "\n", + "\n", + "# Unity AI Gateway endpoint this agent routes every LLM call through. 
Distinct\n", + "# from LLM_MODEL (a foundation model used elsewhere); the param is only declared\n", + "# on the `all` target, so fall back to LLM_MODEL when it is absent.\n", + "AI_GATEWAY_ENDPOINT_NAME = _param(\"AI_GATEWAY_ENDPOINT_NAME\") or LLM_MODEL\n", + "\n", + "# App name from the deploy-time-baked COMPLAINT_AGENT_APP_NAME param so it\n", + "# carries the deploy-time catalog even when --params CATALOG disagrees with\n", + "# --var catalog; the resolver re-sanitises and falls back to deriving from\n", + "# CATALOG.\n", + "APP_NAME = resolve_agent_app_name(_param(\"COMPLAINT_AGENT_APP_NAME\"), CATALOG, \"complaint\")\n", "UC_MODEL_NAME = f\"{CATALOG}.ai.complaint_agent_app\"\n", "print(f\"Complaint agent app: {APP_NAME}\")\n" ], @@ -358,12 +375,12 @@ " _agent_py = f.read()\n", "\n", "_match = re.search(\n", - " r'class ComplaintTriage\\(dspy\\.Signature\\):\\s*\"\"\"(.*?)\"\"\"',\n", + " r'COMPLAINT_TRIAGE_PROMPT\\s*=\\s*\"\"\"(.*?)\"\"\"',\n", " _agent_py,\n", " re.DOTALL,\n", ")\n", "if not _match:\n", - " raise RuntimeError(\"Could not extract ComplaintTriage docstring from complaint app source\")\n", + " raise RuntimeError(\"Could not extract COMPLAINT_TRIAGE_PROMPT from complaint app source\")\n", "_current_complaint_prompt = _match.group(1).strip()\n", "\n", "_COMPLAINT_V1 = (\n", @@ -396,7 +413,7 @@ " \"stage\": \"complaint_agent\",\n", " \"app_name\": APP_NAME,\n", " \"uc_model\": UC_MODEL_NAME,\n", - " \"consumed_via\": \"DSPy Signature docstring in Databricks App source\",\n", + " \"consumed_via\": \"Gateway-backed DSPy prompt constant in Databricks App source\",\n", "}\n", "\n", "seed_prompt_history(\n", @@ -417,7 +434,7 @@ " ],\n", " current={\n", " \"template\": _current_complaint_prompt,\n", - " \"commit_message\": \"v3 (production): percentile-based credit calc + timing tools, deployed as Databricks App\",\n", + " \"commit_message\": \"v3 (production): percentile-based credit calc + UC lookups, deployed as Databricks App\",\n", " \"tags\": 
{**_common_tags, \"deployment_kind\": \"databricks_app\"},\n", " },\n", ")\n" @@ -449,18 +466,13 @@ "from databricks.sdk import WorkspaceClient\n", "from databricks.sdk.service import catalog as catalog_svc\n", "from databricks.sdk.service.apps import App, AppDeployment\n", - "from databricks.sdk.service.serving import (\n", - " ServingEndpointAccessControlRequest,\n", - " ServingEndpointPermissionLevel,\n", - ")\n", - "\n", "w = WorkspaceClient()\n", "source_code_path = os.path.abspath(\"../apps/complaint-agent\")\n", "print(f\"App name: {APP_NAME}\")\n", "print(f\"App source: {source_code_path}\")\n", "\n", - "gateway_chat_probe(llm_model=LLM_MODEL, w=w, dbutils=dbutils)\n", - "print(f\"Verified {LLM_MODEL} is queryable through Unity AI Gateway\")\n", + "gateway_chat_probe(llm_model=AI_GATEWAY_ENDPOINT_NAME, w=w, dbutils=dbutils)\n", + "print(f\"Verified {AI_GATEWAY_ENDPOINT_NAME} is queryable through Unity AI Gateway\")\n", "\n", "app_yaml_path = os.path.join(source_code_path, \"app.yaml\")\n", "app_yaml_contents = f\"\"\"command:\n", @@ -469,8 +481,8 @@ "env:\n", " - name: DATABRICKS_CATALOG\n", " value: '{CATALOG}'\n", - " - name: LLM_MODEL\n", - " value: '{LLM_MODEL}'\n", + " - name: AI_GATEWAY_ENDPOINT_NAME\n", + " value: '{AI_GATEWAY_ENDPOINT_NAME}'\n", " - name: MLFLOW_EXPERIMENT_ID\n", " value: '{prod_experiment_id}'\n", " - name: MLFLOW_TRACKING_URI\n", @@ -554,29 +566,15 @@ " except Exception as e:\n", " print(f\"Could not grant {privilege} on {full_name} to {app_uc_principal}: {e}\")\n", "\n", - "# LLM_MODEL names a Unity AI Gateway-backed endpoint. 
We only use the serving\n", - "# endpoint permissions API to grant CAN_QUERY on that Gateway endpoint; no LLM\n", - "# request is routed through legacy model-serving invocation routes.\n", - "llm_endpoint = None\n", - "try:\n", - " llm_endpoint = w.serving_endpoints.get(LLM_MODEL)\n", - "except Exception:\n", - " matches = [ep for ep in w.serving_endpoints.list() if ep.name == LLM_MODEL]\n", - " if matches:\n", - " llm_endpoint = matches[0]\n", - "if llm_endpoint is None or not getattr(llm_endpoint, \"id\", None):\n", - " raise RuntimeError(f\"Could not resolve Gateway endpoint {LLM_MODEL} for permission grant\")\n", - "\n", - "w.serving_endpoints.update_permissions(\n", - " serving_endpoint_id=llm_endpoint.id,\n", - " access_control_list=[\n", - " ServingEndpointAccessControlRequest(\n", - " service_principal_name=app_sp_id,\n", - " permission_level=ServingEndpointPermissionLevel.CAN_QUERY,\n", - " )\n", - " ],\n", - ")\n", - "print(f\"Granted CAN_QUERY on Gateway endpoint {LLM_MODEL} to app SP {app_sp_id}\")\n", + "# AI_GATEWAY_ENDPOINT_NAME is a Unity AI Gateway endpoint name. The stage probes\n", + "# it through /ai-gateway/mlflow/v1 above with the notebook identity, and the\n", + "# Databricks App validates the same endpoint at startup with the app service\n", + "# principal. Do not resolve it through the model-serving endpoint APIs: Beta\n", + "# Unity AI Gateway endpoints are queried through the Gateway API and need not\n", + "# exist as model-serving endpoint resources. 
NOTE: the v2 Beta gateway has no\n", + "# permissions API, so CAN_QUERY for the app SP must be granted manually in the\n", + "# AI Gateway UI (see runbook) \u2014 it cannot be granted from this stage.\n", + "print(f\"Gateway endpoint {AI_GATEWAY_ENDPOINT_NAME} will be validated by the app runtime at startup\")\n", "\n", "try:\n", " w.api_client.do(\n", @@ -728,4 +726,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/stages/complaint_agent_stream.ipynb b/stages/complaint_agent_stream.ipynb index c074f3d..7e78334 100644 --- a/stages/complaint_agent_stream.ipynb +++ b/stages/complaint_agent_stream.ipynb @@ -28,12 +28,16 @@ "import sys\n", "\n", "sys.path.append('../utils')\n", - "from agent_app_client import complaint_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", "w = WorkspaceClient()\n", "\n", "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "COMPLAINT_AGENT_APP_NAME = complaint_agent_app_name(CATALOG)\n", + "try:\n", + " _APP_NAME_PARAM = dbutils.widgets.get(\"COMPLAINT_AGENT_APP_NAME\")\n", + "except Exception:\n", + " _APP_NAME_PARAM = \"\"\n", + "COMPLAINT_AGENT_APP_NAME = resolve_agent_app_name(_APP_NAME_PARAM, CATALOG, \"complaint\")\n", "\n", "notebook_abs_path = os.path.abspath(\"../jobs/complaint_agent_stream\")\n", "notebook_dbx_path = notebook_abs_path.replace(\n", diff --git a/stages/complaint_evaluation.ipynb b/stages/complaint_evaluation.ipynb index a8bdec8..15ac664 100644 --- a/stages/complaint_evaluation.ipynb +++ b/stages/complaint_evaluation.ipynb @@ -36,9 +36,13 @@ "import os\n", "import sys\n", "sys.path.append(os.path.abspath(\"../utils\"))\n", - "from agent_app_client import complaint_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", - "APP_NAME = complaint_agent_app_name(CATALOG)\n", + "try:\n", + " _APP_NAME_PARAM = dbutils.widgets.get(\"COMPLAINT_AGENT_APP_NAME\")\n", + "except Exception:\n", + " _APP_NAME_PARAM = \"\"\n", + 
"APP_NAME = resolve_agent_app_name(_APP_NAME_PARAM, CATALOG, \"complaint\")\n", "\n", "try:\n", " SKIP_EVAL = dbutils.widgets.get(\"SKIP_EVAL\").strip().lower() == \"true\"\n", diff --git a/stages/operational_app.ipynb b/stages/operational_app.ipynb index 1ae9f5e..021450f 100644 --- a/stages/operational_app.ipynb +++ b/stages/operational_app.ipynb @@ -71,7 +71,14 @@ "# warehouse is missing, deploy hasn't been run (or the resource was removed).\n", "# Cleanup is also DABs-owned (`bundle destroy -t all`), so we no longer\n", "# register the warehouse with uc_state to avoid double-delete.\n", - "WAREHOUSE_NAME = f\"{CATALOG}-ops-warehouse\"\n", + "# Read from the OPS_WAREHOUSE_NAME job parameter (baked from ${var.catalog} at\n", + "# deploy time) so the name carries the deploy-time catalog even when\n", + "# --params CATALOG disagrees with --var catalog. Fall back to the legacy\n", + "# reconstruction for standalone-notebook runs where the widget is absent.\n", + "try:\n", + " WAREHOUSE_NAME = dbutils.widgets.get(\"OPS_WAREHOUSE_NAME\") or f\"{CATALOG}-ops-warehouse\"\n", + "except Exception:\n", + " WAREHOUSE_NAME = f\"{CATALOG}-ops-warehouse\"\n", "existing_wh = [wh for wh in w.warehouses.list() if wh.name == WAREHOUSE_NAME]\n", "if not existing_wh:\n", " raise RuntimeError(\n", @@ -591,7 +598,7 @@ "metadata": {}, "source": [ "import json as _json, re as _re, os as _os\n", - "from agent_app_client import refund_agent_app_name, complaint_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", "def _latest_from_uc_state(resource_type, key):\n", " try:\n", @@ -713,8 +720,18 @@ "lakebase_endpoint_path = OPS_ENDPOINT_PATH\n", "\n", "# Custom agent Apps \u2014 names are deterministic and catalog-scoped.\n", - "refund_agent_app = refund_agent_app_name(CATALOG)\n", - "complaint_agent_app = complaint_agent_app_name(CATALOG)\n", + "def _param(name, default=\"\"):\n", + " try:\n", + " return (dbutils.widgets.get(name) or \"\").strip() or default\n", 
+ " except Exception:\n", + " return default\n", + "\n", + "\n", + "# Prefer the deploy-time-baked *_APP_NAME params so the dashboard embeds the\n", + "# same App names the agent stages deployed, even when --params CATALOG disagrees\n", + "# with --var catalog; the resolver re-sanitises and falls back to deriving.\n", + "refund_agent_app = resolve_agent_app_name(_param(\"REFUND_AGENT_APP_NAME\"), CATALOG, \"refund\")\n", + "complaint_agent_app = resolve_agent_app_name(_param(\"COMPLAINT_AGENT_APP_NAME\"), CATALOG, \"complaint\")\n", "\n", "def _app_url_or_empty(app_name):\n", " try:\n", diff --git a/stages/refund_evaluation.ipynb b/stages/refund_evaluation.ipynb index 589927e..c9913f2 100644 --- a/stages/refund_evaluation.ipynb +++ b/stages/refund_evaluation.ipynb @@ -37,9 +37,13 @@ "import os\n", "import sys\n", "sys.path.append(os.path.abspath(\"../utils\"))\n", - "from agent_app_client import refund_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", - "APP_NAME = refund_agent_app_name(CATALOG)\n", + "try:\n", + " _APP_NAME_PARAM = dbutils.widgets.get(\"REFUND_AGENT_APP_NAME\")\n", + "except Exception:\n", + " _APP_NAME_PARAM = \"\"\n", + "APP_NAME = resolve_agent_app_name(_APP_NAME_PARAM, CATALOG, \"refund\")\n", "\n", "try:\n", " SKIP_EVAL = dbutils.widgets.get(\"SKIP_EVAL\").strip().lower() == \"true\"\n", diff --git a/stages/refunder_agent.ipynb b/stages/refunder_agent.ipynb index 09ef7d0..2e411cd 100644 --- a/stages/refunder_agent.ipynb +++ b/stages/refunder_agent.ipynb @@ -331,9 +331,25 @@ "\n", "import sys\n", "sys.path.append('../utils')\n", - "from agent_app_client import refund_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", - "APP_NAME = refund_agent_app_name(CATALOG)\n", + "\n", + "def _param(name, default=\"\"):\n", + " try:\n", + " return (dbutils.widgets.get(name) or \"\").strip() or default\n", + " except Exception:\n", + " return default\n", + "\n", + "\n", + "# Unity AI Gateway 
endpoint this agent routes every LLM call through. Distinct\n", + "# from LLM_MODEL (a foundation model used elsewhere); the param is only declared\n", + "# on the `all` target, so fall back to LLM_MODEL when it is absent.\n", + "AI_GATEWAY_ENDPOINT_NAME = _param(\"AI_GATEWAY_ENDPOINT_NAME\") or LLM_MODEL\n", + "\n", + "# App name from the deploy-time-baked REFUND_AGENT_APP_NAME param so it carries\n", + "# the deploy-time catalog even when --params CATALOG disagrees with --var\n", + "# catalog; the resolver re-sanitises and falls back to deriving from CATALOG.\n", + "APP_NAME = resolve_agent_app_name(_param(\"REFUND_AGENT_APP_NAME\"), CATALOG, \"refund\")\n", "UC_MODEL_NAME = f\"{CATALOG}.ai.refund_agent_app\"\n", "print(f\"Refund agent app: {APP_NAME}\")\n" ] @@ -520,18 +536,13 @@ "from databricks.sdk import WorkspaceClient\n", "from databricks.sdk.service import catalog as catalog_svc\n", "from databricks.sdk.service.apps import App, AppDeployment\n", - "from databricks.sdk.service.serving import (\n", - " ServingEndpointAccessControlRequest,\n", - " ServingEndpointPermissionLevel,\n", - ")\n", - "\n", "w = WorkspaceClient()\n", "source_code_path = os.path.abspath(\"../apps/refund-agent\")\n", "print(f\"App name: {APP_NAME}\")\n", "print(f\"App source: {source_code_path}\")\n", "\n", - "gateway_chat_probe(llm_model=LLM_MODEL, w=w, dbutils=dbutils)\n", - "print(f\"Verified {LLM_MODEL} is queryable through Unity AI Gateway\")\n", + "gateway_chat_probe(llm_model=AI_GATEWAY_ENDPOINT_NAME, w=w, dbutils=dbutils)\n", + "print(f\"Verified {AI_GATEWAY_ENDPOINT_NAME} is queryable through Unity AI Gateway\")\n", "\n", "app_yaml_path = os.path.join(source_code_path, \"app.yaml\")\n", "app_yaml_contents = f\"\"\"command:\n", @@ -540,8 +551,8 @@ "env:\n", " - name: DATABRICKS_CATALOG\n", " value: '{CATALOG}'\n", - " - name: LLM_MODEL\n", - " value: '{LLM_MODEL}'\n", + " - name: AI_GATEWAY_ENDPOINT_NAME\n", + " value: '{AI_GATEWAY_ENDPOINT_NAME}'\n", " - name: 
MLFLOW_EXPERIMENT_ID\n", " value: '{prod_experiment_id}'\n", " - name: MLFLOW_TRACKING_URI\n", @@ -625,29 +636,15 @@ " except Exception as e:\n", " print(f\"Could not grant {privilege} on {full_name} to {app_uc_principal}: {e}\")\n", "\n", - "# LLM_MODEL names a Unity AI Gateway-backed endpoint. We only use the serving\n", - "# endpoint permissions API to grant CAN_QUERY on that Gateway endpoint; no LLM\n", - "# request is routed through legacy model-serving invocation routes.\n", - "llm_endpoint = None\n", - "try:\n", - " llm_endpoint = w.serving_endpoints.get(LLM_MODEL)\n", - "except Exception:\n", - " matches = [ep for ep in w.serving_endpoints.list() if ep.name == LLM_MODEL]\n", - " if matches:\n", - " llm_endpoint = matches[0]\n", - "if llm_endpoint is None or not getattr(llm_endpoint, \"id\", None):\n", - " raise RuntimeError(f\"Could not resolve Gateway endpoint {LLM_MODEL} for permission grant\")\n", - "\n", - "w.serving_endpoints.update_permissions(\n", - " serving_endpoint_id=llm_endpoint.id,\n", - " access_control_list=[\n", - " ServingEndpointAccessControlRequest(\n", - " service_principal_name=app_sp_id,\n", - " permission_level=ServingEndpointPermissionLevel.CAN_QUERY,\n", - " )\n", - " ],\n", - ")\n", - "print(f\"Granted CAN_QUERY on Gateway endpoint {LLM_MODEL} to app SP {app_sp_id}\")\n", + "# AI_GATEWAY_ENDPOINT_NAME is a Unity AI Gateway endpoint name. The stage probes\n", + "# it through /ai-gateway/mlflow/v1 above with the notebook identity, and the\n", + "# Databricks App validates the same endpoint at startup with the app service\n", + "# principal. Do not resolve it through the model-serving endpoint APIs: Beta\n", + "# Unity AI Gateway endpoints are queried through the Gateway API and need not\n", + "# exist as model-serving endpoint resources. 
NOTE: the v2 Beta gateway has no\n", + "# permissions API, so CAN_QUERY for the app SP must be granted manually in the\n", + "# AI Gateway UI (see runbook) \u2014 it cannot be granted from this stage.\n", + "print(f\"Gateway endpoint {AI_GATEWAY_ENDPOINT_NAME} will be validated by the app runtime at startup\")\n", "\n", "try:\n", " w.api_client.do(\n", @@ -870,4 +867,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} \ No newline at end of file +} diff --git a/stages/refunder_stream.ipynb b/stages/refunder_stream.ipynb index f8d17fd..2d4c6aa 100644 --- a/stages/refunder_stream.ipynb +++ b/stages/refunder_stream.ipynb @@ -75,12 +75,16 @@ "import sys\n", "\n", "sys.path.append('../utils')\n", - "from agent_app_client import refund_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", "w = WorkspaceClient()\n", "\n", "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", - "REFUND_AGENT_APP_NAME = refund_agent_app_name(CATALOG)\n", + "try:\n", + " _APP_NAME_PARAM = dbutils.widgets.get(\"REFUND_AGENT_APP_NAME\")\n", + "except Exception:\n", + " _APP_NAME_PARAM = \"\"\n", + "REFUND_AGENT_APP_NAME = resolve_agent_app_name(_APP_NAME_PARAM, CATALOG, \"refund\")\n", "\n", "notebook_abs_path = os.path.abspath(\"../jobs/refund_recommender_stream\")\n", "notebook_dbx_path = notebook_abs_path.replace(\n", diff --git a/utils/agent_app_client.py b/utils/agent_app_client.py index 3af7aaa..70dc783 100644 --- a/utils/agent_app_client.py +++ b/utils/agent_app_client.py @@ -6,7 +6,7 @@ import json import os import re -from typing import Any, Iterable +from typing import Any import requests from databricks.sdk import WorkspaceClient @@ -30,12 +30,26 @@ def _safe_app_name(value: str) -> str: return f"{prefix}-{digest}" +def resolve_agent_app_name(param_value: str, catalog: str, role: str) -> str: + """Resolve an agent App name. + + Prefers the deploy-time-baked job-param value (e.g. 
REFUND_AGENT_APP_NAME) + so the name carries the deploy-time catalog even when `--params CATALOG` + disagrees with `--var catalog`. Falls back to deriving + `<role>-agent-<catalog>` for standalone-notebook runs where the param is + absent. Always sanitised to the Databricks Apps name rules (idempotent, + so passing an already-resolved name back through is safe). + """ + raw = (param_value or "").strip() or f"{role}-agent-{catalog}" + return _safe_app_name(raw) + + def refund_agent_app_name(catalog: str) -> str: - return _safe_app_name(f"refund-agent-{catalog}") + return resolve_agent_app_name("", catalog, "refund") def complaint_agent_app_name(catalog: str) -> str: - return _safe_app_name(f"complaint-agent-{catalog}") + return resolve_agent_app_name("", catalog, "complaint") def get_notebook_token(dbutils: Any) -> str: @@ -166,20 +180,13 @@ def extract_response_text(response: Any) -> str: if not isinstance(response, dict): raise TypeError(f"Unsupported response type: {type(response).__name__}") - direct = response.get("output_text") - if isinstance(direct, str) and direct: - return direct - - for item in _iter_response_items(response.get("output", [])): - content = item.get("content") - if isinstance(content, str) and content: - return content - for content_item in _iter_response_items(content or []): - text = content_item.get("text") - if isinstance(text, str) and text: - return text - - raise ValueError(f"Could not extract output text from response keys: {sorted(response.keys())}") + try: + text = response["output"][0]["content"][0]["text"] + except (KeyError, IndexError, TypeError) as exc: + raise ValueError(f"Unexpected Responses API shape: {sorted(response.keys())}") from exc + if not isinstance(text, str) or not text: + raise ValueError("Responses API output text is empty") + return text def call_agent_app_text(**kwargs: Any) -> str: @@ -213,12 +220,3 @@ def gateway_chat_probe( timeout=timeout, ) response.raise_for_status() - - -def _iter_response_items(value: Any) -> 
Iterable[dict[str, Any]]: - if isinstance(value, dict): - yield value - elif isinstance(value, list): - for item in value: - if isinstance(item, dict): - yield item From 5715c2e2f8bb5811393217e8bf8581ac728ebd1b Mon Sep 17 00:00:00 2001 From: djliden <7102904+djliden@users.noreply.github.com> Date: Wed, 10 Jun 2026 09:14:44 -0500 Subject: [PATCH 3/3] Fix ops dashboard agent app access --- databricks.yml | 2 ++ jobs/complaint_agent_stream.ipynb | 18 +++++------ jobs/refund_recommender_stream.ipynb | 36 ++++++++++----------- stages/operational_app.ipynb | 48 ++++++++++++++++++++++++---- 4 files changed, 71 insertions(+), 33 deletions(-) diff --git a/databricks.yml b/databricks.yml index 51a419e..8a3b422 100644 --- a/databricks.yml +++ b/databricks.yml @@ -841,6 +841,8 @@ targets: - task_key: Operational_App max_retries: 0 depends_on: + - task_key: Refund_Recommender_Agent + - task_key: Complaint_Agent - task_key: Operational_Supervisor - task_key: Operational_Lakebase notebook_task: diff --git a/jobs/complaint_agent_stream.ipynb b/jobs/complaint_agent_stream.ipynb index d28a330..a873fcd 100644 --- a/jobs/complaint_agent_stream.ipynb +++ b/jobs/complaint_agent_stream.ipynb @@ -9,6 +9,15 @@ "This notebook streams complaints through the complaint agent for processing" ] }, + { + "cell_type": "code", + "metadata": {}, + "source": [ + "%pip install -U -qqqq databricks-sdk requests\n" + ], + "execution_count": null, + "outputs": [] + }, { "cell_type": "code", "metadata": {}, @@ -36,15 +45,6 @@ "execution_count": null, "outputs": [] }, - { - "cell_type": "code", - "metadata": {}, - "source": [ - "%pip install -U -qqqq databricks-sdk requests\n" - ], - "execution_count": null, - "outputs": [] - }, { "cell_type": "code", "metadata": {}, diff --git a/jobs/refund_recommender_stream.ipynb b/jobs/refund_recommender_stream.ipynb index 3306a58..80ae6f4 100644 --- a/jobs/refund_recommender_stream.ipynb +++ b/jobs/refund_recommender_stream.ipynb @@ -31,6 +31,24 @@ "title": "" } }, + 
"source": [ + "%pip install -U -qqqq databricks-sdk requests\n" + ], + "execution_count": 0, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "application/vnd.databricks.v1+cell": { + "cellMetadata": {}, + "inputWidgets": {}, + "nuid": "b2af6e4b-0a96-4810-9c15-eedb7e84db67", + "showTitle": false, + "tableResultSettingsMap": {}, + "title": "" + } + }, "source": [ "CATALOG = dbutils.widgets.get(\"CATALOG\")\n", "\n", @@ -55,24 +73,6 @@ "execution_count": 0, "outputs": [] }, - { - "cell_type": "code", - "metadata": { - "application/vnd.databricks.v1+cell": { - "cellMetadata": {}, - "inputWidgets": {}, - "nuid": "b2af6e4b-0a96-4810-9c15-eedb7e84db67", - "showTitle": false, - "tableResultSettingsMap": {}, - "title": "" - } - }, - "source": [ - "%pip install -U -qqqq databricks-sdk requests\n" - ], - "execution_count": 0, - "outputs": [] - }, { "cell_type": "code", "metadata": {}, diff --git a/stages/operational_app.ipynb b/stages/operational_app.ipynb index 021450f..1f8a435 100644 --- a/stages/operational_app.ipynb +++ b/stages/operational_app.ipynb @@ -53,7 +53,7 @@ "import sys, os\n", "sys.path.append('../utils')\n", "from uc_state import add\n", - "from agent_app_client import refund_agent_app_name, complaint_agent_app_name\n", + "from agent_app_client import resolve_agent_app_name\n", "\n", "from databricks.sdk import WorkspaceClient\n", "from databricks.sdk.service.apps import (\n", @@ -404,6 +404,27 @@ " print(f\"\\u274c FAILED to grant CAN_RUN on Genie space {space_id}: {e}\")\n", " return \"failed\"\n", "\n", + "def _grant_can_use_app(app_name: str, sp_id: str, retries: int = 3) -> bool:\n", + " \"\"\"Grant CAN_USE on a Databricks App to the ops dashboard app SP.\"\"\"\n", + " for attempt in range(retries):\n", + " try:\n", + " w.api_client.do(\n", + " \"PATCH\",\n", + " f\"/api/2.0/permissions/apps/{app_name}\",\n", + " body={\"access_control_list\": [\n", + " {\"service_principal_name\": sp_id, \"permission_level\": \"CAN_USE\"}\n", + " 
]},\n", + " )\n", + " print(f\"\\u2705 Granted CAN_USE on App {app_name} to {sp_id}\")\n", + " return True\n", + " except Exception as e:\n", + " if attempt < retries - 1:\n", + " print(f\" Retry {attempt+1}/{retries} for App {app_name}: {e}\")\n", + " time.sleep(5 * (attempt + 1))\n", + " else:\n", + " print(f\"\\u274c FAILED to grant CAN_USE on App {app_name}: {e}\")\n", + " return False\n", + "\n", "# \u2500\u2500 Resolve endpoints + Genie spaces created BY THIS CATALOG from uc_state \u2500\u2500\u2500\u2500\n", "# P1-16: previously this cell scanned every mas-*/ka-* endpoint and every Genie\n", "# space in the workspace and granted to the app SP. Two pain points:\n", @@ -442,9 +463,15 @@ " if tile_id:\n", " owned_endpoint_names.add(f\"ka-{tile_id[:8]}-endpoint\")\n", "\n", + "def _job_param(name, default=\"\"):\n", + " try:\n", + " return (dbutils.widgets.get(name) or \"\").strip() or default\n", + " except Exception:\n", + " return default\n", + "\n", "owned_agent_app_names = {\n", - " refund_agent_app_name(CATALOG),\n", - " complaint_agent_app_name(CATALOG),\n", + " resolve_agent_app_name(_job_param(\"REFUND_AGENT_APP_NAME\"), CATALOG, \"refund\"),\n", + " resolve_agent_app_name(_job_param(\"COMPLAINT_AGENT_APP_NAME\"), CATALOG, \"complaint\"),\n", "}\n", "\n", "# Build the set of Genie spaces this catalog owns (`space_id`).\n", @@ -494,6 +521,14 @@ " if not ok:\n", " ep_failures.append(ep.name)\n", "\n", + "# \u2500\u2500 Grant CAN_USE on owned Refund / Complaint Apps \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", + "app_failures = []\n", + "print(f\"Granting CAN_USE on {len(owned_agent_app_names)} owned agent Apps\")\n", + "for app_name in sorted(owned_agent_app_names):\n", + " ok = _grant_can_use_app(app_name, app_sp_id)\n", + " if not ok:\n", + " app_failures.append(app_name)\n", + "\n", "# \u2500\u2500 Grant CAN_RUN on owned Genie spaces 
\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "genie_failures = []\n", "genie_orphans = []\n", @@ -512,14 +547,15 @@ "\n", "# \u2500\u2500 Fail loudly if any grants didn't land \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\n", "# Only real failures count \u2014 orphans are treated as warnings (stale uc_state).\n", - "if ep_failures or genie_failures:\n", + "if ep_failures or app_failures or genie_failures:\n", " raise RuntimeError(\n", " f\"Permission grants failed!\\n\"\n", " f\" Endpoints: {ep_failures}\\n\"\n", + " f\" Apps: {app_failures}\\n\"\n", " f\" Genie spaces: {genie_failures}\\n\"\n", " \"Fix the errors above and re-run this cell.\"\n", " )\n", - "print(\"\\n\\u2705 All endpoint and Genie space permissions granted successfully\")" + "print(\"\\n\\u2705 All endpoint, agent App, and Genie space permissions granted successfully\")" ], "execution_count": null, "outputs": [] @@ -899,4 +935,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} \ No newline at end of file +}